aboutsummaryrefslogtreecommitdiff
path: root/lib/Target/ARM/ARMParallelDSP.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'lib/Target/ARM/ARMParallelDSP.cpp')
-rw-r--r--lib/Target/ARM/ARMParallelDSP.cpp675
1 files changed, 323 insertions, 352 deletions
diff --git a/lib/Target/ARM/ARMParallelDSP.cpp b/lib/Target/ARM/ARMParallelDSP.cpp
index 5389d09bf7d7..ae5657a0a2c1 100644
--- a/lib/Target/ARM/ARMParallelDSP.cpp
+++ b/lib/Target/ARM/ARMParallelDSP.cpp
@@ -1,4 +1,4 @@
-//===- ParallelDSP.cpp - Parallel DSP Pass --------------------------------===//
+//===- ARMParallelDSP.cpp - Parallel DSP Pass -----------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -18,13 +18,11 @@
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/Analysis/AliasAnalysis.h"
#include "llvm/Analysis/LoopAccessAnalysis.h"
-#include "llvm/Analysis/LoopPass.h"
-#include "llvm/Analysis/LoopInfo.h"
+#include "llvm/Analysis/OrderedBasicBlock.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/NoFolder.h"
#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
-#include "llvm/Transforms/Utils/LoopUtils.h"
#include "llvm/Pass.h"
#include "llvm/PassRegistry.h"
#include "llvm/PassSupport.h"
@@ -45,54 +43,39 @@ static cl::opt<bool>
DisableParallelDSP("disable-arm-parallel-dsp", cl::Hidden, cl::init(false),
cl::desc("Disable the ARM Parallel DSP pass"));
+static cl::opt<unsigned>
+NumLoadLimit("arm-parallel-dsp-load-limit", cl::Hidden, cl::init(16),
+ cl::desc("Limit the number of loads analysed"));
+
namespace {
- struct OpChain;
- struct BinOpChain;
+ struct MulCandidate;
class Reduction;
- using OpChainList = SmallVector<std::unique_ptr<OpChain>, 8>;
- using ReductionList = SmallVector<Reduction, 8>;
- using ValueList = SmallVector<Value*, 8>;
- using MemInstList = SmallVector<LoadInst*, 8>;
- using PMACPair = std::pair<BinOpChain*,BinOpChain*>;
- using PMACPairList = SmallVector<PMACPair, 8>;
- using Instructions = SmallVector<Instruction*,16>;
- using MemLocList = SmallVector<MemoryLocation, 4>;
+ using MulCandList = SmallVector<std::unique_ptr<MulCandidate>, 8>;
+ using MemInstList = SmallVectorImpl<LoadInst*>;
+ using MulPairList = SmallVector<std::pair<MulCandidate*, MulCandidate*>, 8>;
- struct OpChain {
+ // 'MulCandidate' holds the multiplication instructions that are candidates
+ // for parallel execution.
+ struct MulCandidate {
Instruction *Root;
- ValueList AllValues;
- MemInstList VecLd; // List of all load instructions.
- MemInstList Loads;
+ Value* LHS;
+ Value* RHS;
+ bool Exchange = false;
bool ReadOnly = true;
+ bool Paired = false;
+ SmallVector<LoadInst*, 2> VecLd; // Container for loads to widen.
- OpChain(Instruction *I, ValueList &vl) : Root(I), AllValues(vl) { }
- virtual ~OpChain() = default;
+ MulCandidate(Instruction *I, Value *lhs, Value *rhs) :
+ Root(I), LHS(lhs), RHS(rhs) { }
- void PopulateLoads() {
- for (auto *V : AllValues) {
- if (auto *Ld = dyn_cast<LoadInst>(V))
- Loads.push_back(Ld);
- }
+ bool HasTwoLoadInputs() const {
+ return isa<LoadInst>(LHS) && isa<LoadInst>(RHS);
}
- unsigned size() const { return AllValues.size(); }
- };
-
- // 'BinOpChain' holds the multiplication instructions that are candidates
- // for parallel execution.
- struct BinOpChain : public OpChain {
- ValueList LHS; // List of all (narrow) left hand operands.
- ValueList RHS; // List of all (narrow) right hand operands.
- bool Exchange = false;
-
- BinOpChain(Instruction *I, ValueList &lhs, ValueList &rhs) :
- OpChain(I, lhs), LHS(lhs), RHS(rhs) {
- for (auto *V : RHS)
- AllValues.push_back(V);
- }
-
- bool AreSymmetrical(BinOpChain *Other);
+ LoadInst *getBaseLoad() const {
+ return VecLd.front();
+ }
};
/// Represent a sequence of multiply-accumulate operations with the aim to
@@ -100,9 +83,9 @@ namespace {
class Reduction {
Instruction *Root = nullptr;
Value *Acc = nullptr;
- OpChainList Muls;
- PMACPairList MulPairs;
- SmallPtrSet<Instruction*, 4> Adds;
+ MulCandList Muls;
+ MulPairList MulPairs;
+ SetVector<Instruction*> Adds;
public:
Reduction() = delete;
@@ -112,10 +95,35 @@ namespace {
/// Record an Add instruction that is a part of the this reduction.
void InsertAdd(Instruction *I) { Adds.insert(I); }
- /// Record a BinOpChain, rooted at a Mul instruction, that is a part of
- /// this reduction.
- void InsertMul(Instruction *I, ValueList &LHS, ValueList &RHS) {
- Muls.push_back(make_unique<BinOpChain>(I, LHS, RHS));
+ /// Create MulCandidates, each rooted at a Mul instruction, that is a part
+ /// of this reduction.
+ void InsertMuls() {
+ auto GetMulOperand = [](Value *V) -> Instruction* {
+ if (auto *SExt = dyn_cast<SExtInst>(V)) {
+ if (auto *I = dyn_cast<Instruction>(SExt->getOperand(0)))
+ if (I->getOpcode() == Instruction::Mul)
+ return I;
+ } else if (auto *I = dyn_cast<Instruction>(V)) {
+ if (I->getOpcode() == Instruction::Mul)
+ return I;
+ }
+ return nullptr;
+ };
+
+ auto InsertMul = [this](Instruction *I) {
+ Value *LHS = cast<Instruction>(I->getOperand(0))->getOperand(0);
+ Value *RHS = cast<Instruction>(I->getOperand(1))->getOperand(0);
+ Muls.push_back(std::make_unique<MulCandidate>(I, LHS, RHS));
+ };
+
+ for (auto *Add : Adds) {
+ if (Add == Acc)
+ continue;
+ if (auto *Mul = GetMulOperand(Add->getOperand(0)))
+ InsertMul(Mul);
+ if (auto *Mul = GetMulOperand(Add->getOperand(1)))
+ InsertMul(Mul);
+ }
}
/// Add the incoming accumulator value, returns true if a value had not
@@ -128,9 +136,17 @@ namespace {
return true;
}
- /// Set two BinOpChains, rooted at muls, that can be executed as a single
+ /// Set two MulCandidates, rooted at muls, that can be executed as a single
/// parallel operation.
- void AddMulPair(BinOpChain *Mul0, BinOpChain *Mul1) {
+ void AddMulPair(MulCandidate *Mul0, MulCandidate *Mul1,
+ bool Exchange = false) {
+ LLVM_DEBUG(dbgs() << "Pairing:\n"
+ << *Mul0->Root << "\n"
+ << *Mul1->Root << "\n");
+ Mul0->Paired = true;
+ Mul1->Paired = true;
+ if (Exchange)
+ Mul1->Exchange = true;
MulPairs.push_back(std::make_pair(Mul0, Mul1));
}
@@ -141,24 +157,40 @@ namespace {
/// Return the add instruction which is the root of the reduction.
Instruction *getRoot() { return Root; }
+ bool is64Bit() const { return Root->getType()->isIntegerTy(64); }
+
+ Type *getType() const { return Root->getType(); }
+
/// Return the incoming value to be accumulated. This maybe null.
Value *getAccumulator() { return Acc; }
/// Return the set of adds that comprise the reduction.
- SmallPtrSetImpl<Instruction*> &getAdds() { return Adds; }
+ SetVector<Instruction*> &getAdds() { return Adds; }
- /// Return the BinOpChain, rooted at mul instruction, that comprise the
+ /// Return the MulCandidate, rooted at mul instruction, that comprise the
/// the reduction.
- OpChainList &getMuls() { return Muls; }
+ MulCandList &getMuls() { return Muls; }
- /// Return the BinOpChain, rooted at mul instructions, that have been
+ /// Return the MulCandidate, rooted at mul instructions, that have been
/// paired for parallel execution.
- PMACPairList &getMulPairs() { return MulPairs; }
+ MulPairList &getMulPairs() { return MulPairs; }
/// To finalise, replace the uses of the root with the intrinsic call.
void UpdateRoot(Instruction *SMLAD) {
Root->replaceAllUsesWith(SMLAD);
}
+
+ void dump() {
+ LLVM_DEBUG(dbgs() << "Reduction:\n";
+ for (auto *Add : Adds)
+ LLVM_DEBUG(dbgs() << *Add << "\n");
+ for (auto &Mul : Muls)
+ LLVM_DEBUG(dbgs() << *Mul->Root << "\n"
+ << " " << *Mul->LHS << "\n"
+ << " " << *Mul->RHS << "\n");
+ LLVM_DEBUG(if (Acc) dbgs() << "Acc in: " << *Acc << "\n")
+ );
+ }
};
class WidenedLoad {
@@ -176,13 +208,11 @@ namespace {
}
};
- class ARMParallelDSP : public LoopPass {
+ class ARMParallelDSP : public FunctionPass {
ScalarEvolution *SE;
AliasAnalysis *AA;
TargetLibraryInfo *TLI;
DominatorTree *DT;
- LoopInfo *LI;
- Loop *L;
const DataLayout *DL;
Module *M;
std::map<LoadInst*, LoadInst*> LoadPairs;
@@ -190,13 +220,12 @@ namespace {
std::map<LoadInst*, std::unique_ptr<WidenedLoad>> WideLoads;
template<unsigned>
- bool IsNarrowSequence(Value *V, ValueList &VL);
-
+ bool IsNarrowSequence(Value *V);
+ bool Search(Value *V, BasicBlock *BB, Reduction &R);
bool RecordMemoryOps(BasicBlock *BB);
void InsertParallelMACs(Reduction &Reduction);
bool AreSequentialLoads(LoadInst *Ld0, LoadInst *Ld1, MemInstList &VecMem);
- LoadInst* CreateWideLoad(SmallVectorImpl<LoadInst*> &Loads,
- IntegerType *LoadTy);
+ LoadInst* CreateWideLoad(MemInstList &Loads, IntegerType *LoadTy);
bool CreateParallelPairs(Reduction &R);
/// Try to match and generate: SMLAD, SMLADX - Signed Multiply Accumulate
@@ -204,60 +233,38 @@ namespace {
/// products to a 32-bit accumulate operand. Optionally, the instruction can
/// exchange the halfwords of the second operand before performing the
/// arithmetic.
- bool MatchSMLAD(Loop *L);
+ bool MatchSMLAD(Function &F);
public:
static char ID;
- ARMParallelDSP() : LoopPass(ID) { }
-
- bool doInitialization(Loop *L, LPPassManager &LPM) override {
- LoadPairs.clear();
- WideLoads.clear();
- return true;
- }
+ ARMParallelDSP() : FunctionPass(ID) { }
void getAnalysisUsage(AnalysisUsage &AU) const override {
- LoopPass::getAnalysisUsage(AU);
+ FunctionPass::getAnalysisUsage(AU);
AU.addRequired<AssumptionCacheTracker>();
AU.addRequired<ScalarEvolutionWrapperPass>();
AU.addRequired<AAResultsWrapperPass>();
AU.addRequired<TargetLibraryInfoWrapperPass>();
- AU.addRequired<LoopInfoWrapperPass>();
AU.addRequired<DominatorTreeWrapperPass>();
AU.addRequired<TargetPassConfig>();
- AU.addPreserved<LoopInfoWrapperPass>();
+ AU.addPreserved<ScalarEvolutionWrapperPass>();
+ AU.addPreserved<GlobalsAAWrapperPass>();
AU.setPreservesCFG();
}
- bool runOnLoop(Loop *TheLoop, LPPassManager &) override {
+ bool runOnFunction(Function &F) override {
if (DisableParallelDSP)
return false;
- L = TheLoop;
+ if (skipFunction(F))
+ return false;
+
SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE();
AA = &getAnalysis<AAResultsWrapperPass>().getAAResults();
- TLI = &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI();
+ TLI = &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
- LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
auto &TPC = getAnalysis<TargetPassConfig>();
- BasicBlock *Header = TheLoop->getHeader();
- if (!Header)
- return false;
-
- // TODO: We assume the loop header and latch to be the same block.
- // This is not a fundamental restriction, but lifting this would just
- // require more work to do the transformation and then patch up the CFG.
- if (Header != TheLoop->getLoopLatch()) {
- LLVM_DEBUG(dbgs() << "The loop header is not the loop latch: not "
- "running pass ARMParallelDSP\n");
- return false;
- }
-
- if (!TheLoop->getLoopPreheader())
- InsertPreheaderForLoop(L, DT, LI, nullptr, true);
-
- Function &F = *Header->getParent();
M = F.getParent();
DL = &M->getDataLayout();
@@ -282,17 +289,10 @@ namespace {
return false;
}
- LoopAccessInfo LAI(L, SE, TLI, AA, DT, LI);
-
LLVM_DEBUG(dbgs() << "\n== Parallel DSP pass ==\n");
LLVM_DEBUG(dbgs() << " - " << F.getName() << "\n\n");
- if (!RecordMemoryOps(Header)) {
- LLVM_DEBUG(dbgs() << " - No sequential loads found.\n");
- return false;
- }
-
- bool Changes = MatchSMLAD(L);
+ bool Changes = MatchSMLAD(F);
return Changes;
}
};
@@ -331,40 +331,14 @@ bool ARMParallelDSP::AreSequentialLoads(LoadInst *Ld0, LoadInst *Ld1,
// TODO: we currently only collect i16, and will support i8 later, so that's
// why we check that types are equal to MaxBitWidth, and not <= MaxBitWidth.
template<unsigned MaxBitWidth>
-bool ARMParallelDSP::IsNarrowSequence(Value *V, ValueList &VL) {
- ConstantInt *CInt;
-
- if (match(V, m_ConstantInt(CInt))) {
- // TODO: if a constant is used, it needs to fit within the bit width.
- return false;
- }
-
- auto *I = dyn_cast<Instruction>(V);
- if (!I)
- return false;
-
- Value *Val, *LHS, *RHS;
- if (match(V, m_Trunc(m_Value(Val)))) {
- if (cast<TruncInst>(I)->getDestTy()->getIntegerBitWidth() == MaxBitWidth)
- return IsNarrowSequence<MaxBitWidth>(Val, VL);
- } else if (match(V, m_Add(m_Value(LHS), m_Value(RHS)))) {
- // TODO: we need to implement sadd16/sadd8 for this, which enables to
- // also do the rewrite for smlad8.ll, but it is unsupported for now.
- return false;
- } else if (match(V, m_ZExtOrSExt(m_Value(Val)))) {
- if (cast<CastInst>(I)->getSrcTy()->getIntegerBitWidth() != MaxBitWidth)
+bool ARMParallelDSP::IsNarrowSequence(Value *V) {
+ if (auto *SExt = dyn_cast<SExtInst>(V)) {
+ if (SExt->getSrcTy()->getIntegerBitWidth() != MaxBitWidth)
return false;
- if (match(Val, m_Load(m_Value()))) {
- auto *Ld = cast<LoadInst>(Val);
-
- // Check that these load could be paired.
- if (!LoadPairs.count(Ld) && !OffsetLoads.count(Ld))
- return false;
-
- VL.push_back(Val);
- VL.push_back(I);
- return true;
+ if (auto *Ld = dyn_cast<LoadInst>(SExt->getOperand(0))) {
+ // Check that this load could be paired.
+ return LoadPairs.count(Ld) || OffsetLoads.count(Ld);
}
}
return false;
@@ -375,6 +349,9 @@ bool ARMParallelDSP::IsNarrowSequence(Value *V, ValueList &VL) {
bool ARMParallelDSP::RecordMemoryOps(BasicBlock *BB) {
SmallVector<LoadInst*, 8> Loads;
SmallVector<Instruction*, 8> Writes;
+ LoadPairs.clear();
+ WideLoads.clear();
+ OrderedBasicBlock OrderedBB(BB);
// Collect loads and instruction that may write to memory. For now we only
// record loads which are simple, sign-extended and have a single user.
@@ -389,21 +366,24 @@ bool ARMParallelDSP::RecordMemoryOps(BasicBlock *BB) {
Loads.push_back(Ld);
}
+ if (Loads.empty() || Loads.size() > NumLoadLimit)
+ return false;
+
using InstSet = std::set<Instruction*>;
using DepMap = std::map<Instruction*, InstSet>;
DepMap RAWDeps;
// Record any writes that may alias a load.
const auto Size = LocationSize::unknown();
- for (auto Read : Loads) {
- for (auto Write : Writes) {
+ for (auto Write : Writes) {
+ for (auto Read : Loads) {
MemoryLocation ReadLoc =
MemoryLocation(Read->getPointerOperand(), Size);
if (!isModOrRefSet(intersectModRef(AA->getModRefInfo(Write, ReadLoc),
ModRefInfo::ModRef)))
continue;
- if (DT->dominates(Write, Read))
+ if (OrderedBB.dominates(Write, Read))
RAWDeps[Read].insert(Write);
}
}
@@ -411,17 +391,16 @@ bool ARMParallelDSP::RecordMemoryOps(BasicBlock *BB) {
// Check whether there's not a write between the two loads which would
// prevent them from being safely merged.
auto SafeToPair = [&](LoadInst *Base, LoadInst *Offset) {
- LoadInst *Dominator = DT->dominates(Base, Offset) ? Base : Offset;
- LoadInst *Dominated = DT->dominates(Base, Offset) ? Offset : Base;
+ LoadInst *Dominator = OrderedBB.dominates(Base, Offset) ? Base : Offset;
+ LoadInst *Dominated = OrderedBB.dominates(Base, Offset) ? Offset : Base;
if (RAWDeps.count(Dominated)) {
InstSet &WritesBefore = RAWDeps[Dominated];
for (auto Before : WritesBefore) {
-
// We can't move the second load backward, past a write, to merge
// with the first load.
- if (DT->dominates(Dominator, Before))
+ if (OrderedBB.dominates(Dominator, Before))
return false;
}
}
@@ -431,7 +410,7 @@ bool ARMParallelDSP::RecordMemoryOps(BasicBlock *BB) {
// Record base, offset load pairs.
for (auto *Base : Loads) {
for (auto *Offset : Loads) {
- if (Base == Offset)
+ if (Base == Offset || OffsetLoads.count(Offset))
continue;
if (AreSequentialAccesses<LoadInst>(Base, Offset, *DL, *SE) &&
@@ -453,7 +432,54 @@ bool ARMParallelDSP::RecordMemoryOps(BasicBlock *BB) {
return LoadPairs.size() > 1;
}
-// Loop Pass that needs to identify integer add/sub reductions of 16-bit vector
+// Search recursively back through the operands to find a tree of values that
+// form a multiply-accumulate chain. The search records the Add and Mul
+// instructions that form the reduction and allows us to find a single value
+// to be used as the initial input to the accumlator.
+bool ARMParallelDSP::Search(Value *V, BasicBlock *BB, Reduction &R) {
+ // If we find a non-instruction, try to use it as the initial accumulator
+ // value. This may have already been found during the search in which case
+ // this function will return false, signaling a search fail.
+ auto *I = dyn_cast<Instruction>(V);
+ if (!I)
+ return R.InsertAcc(V);
+
+ if (I->getParent() != BB)
+ return false;
+
+ switch (I->getOpcode()) {
+ default:
+ break;
+ case Instruction::PHI:
+ // Could be the accumulator value.
+ return R.InsertAcc(V);
+ case Instruction::Add: {
+ // Adds should be adding together two muls, or another add and a mul to
+ // be within the mac chain. One of the operands may also be the
+ // accumulator value at which point we should stop searching.
+ R.InsertAdd(I);
+ Value *LHS = I->getOperand(0);
+ Value *RHS = I->getOperand(1);
+ bool ValidLHS = Search(LHS, BB, R);
+ bool ValidRHS = Search(RHS, BB, R);
+
+ if (ValidLHS && ValidRHS)
+ return true;
+
+ return R.InsertAcc(I);
+ }
+ case Instruction::Mul: {
+ Value *MulOp0 = I->getOperand(0);
+ Value *MulOp1 = I->getOperand(1);
+ return IsNarrowSequence<16>(MulOp0) && IsNarrowSequence<16>(MulOp1);
+ }
+ case Instruction::SExt:
+ return Search(I->getOperand(0), BB, R);
+ }
+ return false;
+}
+
+// The pass needs to identify integer add/sub reductions of 16-bit vector
// multiplications.
// To use SMLAD:
// 1) we first need to find integer add then look for this pattern:
@@ -484,88 +510,39 @@ bool ARMParallelDSP::RecordMemoryOps(BasicBlock *BB) {
// If loop invariants are used instead of loads, these need to be packed
// before the loop begins.
//
-bool ARMParallelDSP::MatchSMLAD(Loop *L) {
- // Search recursively back through the operands to find a tree of values that
- // form a multiply-accumulate chain. The search records the Add and Mul
- // instructions that form the reduction and allows us to find a single value
- // to be used as the initial input to the accumlator.
- std::function<bool(Value*, Reduction&)> Search = [&]
- (Value *V, Reduction &R) -> bool {
-
- // If we find a non-instruction, try to use it as the initial accumulator
- // value. This may have already been found during the search in which case
- // this function will return false, signaling a search fail.
- auto *I = dyn_cast<Instruction>(V);
- if (!I)
- return R.InsertAcc(V);
-
- switch (I->getOpcode()) {
- default:
- break;
- case Instruction::PHI:
- // Could be the accumulator value.
- return R.InsertAcc(V);
- case Instruction::Add: {
- // Adds should be adding together two muls, or another add and a mul to
- // be within the mac chain. One of the operands may also be the
- // accumulator value at which point we should stop searching.
- bool ValidLHS = Search(I->getOperand(0), R);
- bool ValidRHS = Search(I->getOperand(1), R);
- if (!ValidLHS && !ValidLHS)
- return false;
- else if (ValidLHS && ValidRHS) {
- R.InsertAdd(I);
- return true;
- } else {
- R.InsertAdd(I);
- return R.InsertAcc(I);
- }
- }
- case Instruction::Mul: {
- Value *MulOp0 = I->getOperand(0);
- Value *MulOp1 = I->getOperand(1);
- if (isa<SExtInst>(MulOp0) && isa<SExtInst>(MulOp1)) {
- ValueList LHS;
- ValueList RHS;
- if (IsNarrowSequence<16>(MulOp0, LHS) &&
- IsNarrowSequence<16>(MulOp1, RHS)) {
- R.InsertMul(I, LHS, RHS);
- return true;
- }
- }
- return false;
- }
- case Instruction::SExt:
- return Search(I->getOperand(0), R);
- }
- return false;
- };
-
+bool ARMParallelDSP::MatchSMLAD(Function &F) {
bool Changed = false;
- SmallPtrSet<Instruction*, 4> AllAdds;
- BasicBlock *Latch = L->getLoopLatch();
- for (Instruction &I : reverse(*Latch)) {
- if (I.getOpcode() != Instruction::Add)
+ for (auto &BB : F) {
+ SmallPtrSet<Instruction*, 4> AllAdds;
+ if (!RecordMemoryOps(&BB))
continue;
- if (AllAdds.count(&I))
- continue;
+ for (Instruction &I : reverse(BB)) {
+ if (I.getOpcode() != Instruction::Add)
+ continue;
- const auto *Ty = I.getType();
- if (!Ty->isIntegerTy(32) && !Ty->isIntegerTy(64))
- continue;
+ if (AllAdds.count(&I))
+ continue;
- Reduction R(&I);
- if (!Search(&I, R))
- continue;
+ const auto *Ty = I.getType();
+ if (!Ty->isIntegerTy(32) && !Ty->isIntegerTy(64))
+ continue;
- if (!CreateParallelPairs(R))
- continue;
+ Reduction R(&I);
+ if (!Search(&I, &BB, R))
+ continue;
- InsertParallelMACs(R);
- Changed = true;
- AllAdds.insert(R.getAdds().begin(), R.getAdds().end());
+ R.InsertMuls();
+ LLVM_DEBUG(dbgs() << "After search, Reduction:\n"; R.dump());
+
+ if (!CreateParallelPairs(R))
+ continue;
+
+ InsertParallelMACs(R);
+ Changed = true;
+ AllAdds.insert(R.getAdds().begin(), R.getAdds().end());
+ }
}
return Changed;
@@ -578,87 +555,57 @@ bool ARMParallelDSP::CreateParallelPairs(Reduction &R) {
return false;
// Check that the muls operate directly upon sign extended loads.
- for (auto &MulChain : R.getMuls()) {
- // A mul has 2 operands, and a narrow op consist of sext and a load; thus
- // we expect at least 4 items in this operand value list.
- if (MulChain->size() < 4) {
- LLVM_DEBUG(dbgs() << "Operand list too short.\n");
+ for (auto &MulCand : R.getMuls()) {
+ if (!MulCand->HasTwoLoadInputs())
return false;
- }
- MulChain->PopulateLoads();
- ValueList &LHS = static_cast<BinOpChain*>(MulChain.get())->LHS;
- ValueList &RHS = static_cast<BinOpChain*>(MulChain.get())->RHS;
-
- // Use +=2 to skip over the expected extend instructions.
- for (unsigned i = 0, e = LHS.size(); i < e; i += 2) {
- if (!isa<LoadInst>(LHS[i]) || !isa<LoadInst>(RHS[i]))
- return false;
- }
}
- auto CanPair = [&](Reduction &R, BinOpChain *PMul0, BinOpChain *PMul1) {
- if (!PMul0->AreSymmetrical(PMul1))
- return false;
-
+ auto CanPair = [&](Reduction &R, MulCandidate *PMul0, MulCandidate *PMul1) {
// The first elements of each vector should be loads with sexts. If we
// find that its two pairs of consecutive loads, then these can be
// transformed into two wider loads and the users can be replaced with
// DSP intrinsics.
- for (unsigned x = 0; x < PMul0->LHS.size(); x += 2) {
- auto *Ld0 = dyn_cast<LoadInst>(PMul0->LHS[x]);
- auto *Ld1 = dyn_cast<LoadInst>(PMul1->LHS[x]);
- auto *Ld2 = dyn_cast<LoadInst>(PMul0->RHS[x]);
- auto *Ld3 = dyn_cast<LoadInst>(PMul1->RHS[x]);
-
- if (!Ld0 || !Ld1 || !Ld2 || !Ld3)
- return false;
+ auto Ld0 = static_cast<LoadInst*>(PMul0->LHS);
+ auto Ld1 = static_cast<LoadInst*>(PMul1->LHS);
+ auto Ld2 = static_cast<LoadInst*>(PMul0->RHS);
+ auto Ld3 = static_cast<LoadInst*>(PMul1->RHS);
- LLVM_DEBUG(dbgs() << "Loads:\n"
- << " - " << *Ld0 << "\n"
- << " - " << *Ld1 << "\n"
- << " - " << *Ld2 << "\n"
- << " - " << *Ld3 << "\n");
-
- if (AreSequentialLoads(Ld0, Ld1, PMul0->VecLd)) {
- if (AreSequentialLoads(Ld2, Ld3, PMul1->VecLd)) {
- LLVM_DEBUG(dbgs() << "OK: found two pairs of parallel loads!\n");
- R.AddMulPair(PMul0, PMul1);
- return true;
- } else if (AreSequentialLoads(Ld3, Ld2, PMul1->VecLd)) {
- LLVM_DEBUG(dbgs() << "OK: found two pairs of parallel loads!\n");
- LLVM_DEBUG(dbgs() << " exchanging Ld2 and Ld3\n");
- PMul1->Exchange = true;
- R.AddMulPair(PMul0, PMul1);
- return true;
- }
- } else if (AreSequentialLoads(Ld1, Ld0, PMul0->VecLd) &&
- AreSequentialLoads(Ld2, Ld3, PMul1->VecLd)) {
+ if (AreSequentialLoads(Ld0, Ld1, PMul0->VecLd)) {
+ if (AreSequentialLoads(Ld2, Ld3, PMul1->VecLd)) {
LLVM_DEBUG(dbgs() << "OK: found two pairs of parallel loads!\n");
- LLVM_DEBUG(dbgs() << " exchanging Ld0 and Ld1\n");
- LLVM_DEBUG(dbgs() << " and swapping muls\n");
- PMul0->Exchange = true;
- // Only the second operand can be exchanged, so swap the muls.
- R.AddMulPair(PMul1, PMul0);
+ R.AddMulPair(PMul0, PMul1);
+ return true;
+ } else if (AreSequentialLoads(Ld3, Ld2, PMul1->VecLd)) {
+ LLVM_DEBUG(dbgs() << "OK: found two pairs of parallel loads!\n");
+ LLVM_DEBUG(dbgs() << " exchanging Ld2 and Ld3\n");
+ R.AddMulPair(PMul0, PMul1, true);
return true;
}
+ } else if (AreSequentialLoads(Ld1, Ld0, PMul0->VecLd) &&
+ AreSequentialLoads(Ld2, Ld3, PMul1->VecLd)) {
+ LLVM_DEBUG(dbgs() << "OK: found two pairs of parallel loads!\n");
+ LLVM_DEBUG(dbgs() << " exchanging Ld0 and Ld1\n");
+ LLVM_DEBUG(dbgs() << " and swapping muls\n");
+ // Only the second operand can be exchanged, so swap the muls.
+ R.AddMulPair(PMul1, PMul0, true);
+ return true;
}
return false;
};
- OpChainList &Muls = R.getMuls();
+ MulCandList &Muls = R.getMuls();
const unsigned Elems = Muls.size();
- SmallPtrSet<const Instruction*, 4> Paired;
for (unsigned i = 0; i < Elems; ++i) {
- BinOpChain *PMul0 = static_cast<BinOpChain*>(Muls[i].get());
- if (Paired.count(PMul0->Root))
+ MulCandidate *PMul0 = static_cast<MulCandidate*>(Muls[i].get());
+ if (PMul0->Paired)
continue;
for (unsigned j = 0; j < Elems; ++j) {
if (i == j)
continue;
- BinOpChain *PMul1 = static_cast<BinOpChain*>(Muls[j].get());
- if (Paired.count(PMul1->Root))
+ MulCandidate *PMul1 = static_cast<MulCandidate*>(Muls[j].get());
+ if (PMul1->Paired)
continue;
const Instruction *Mul0 = PMul0->Root;
@@ -668,29 +615,19 @@ bool ARMParallelDSP::CreateParallelPairs(Reduction &R) {
assert(PMul0 != PMul1 && "expected different chains");
- if (CanPair(R, PMul0, PMul1)) {
- Paired.insert(Mul0);
- Paired.insert(Mul1);
+ if (CanPair(R, PMul0, PMul1))
break;
- }
}
}
return !R.getMulPairs().empty();
}
-
void ARMParallelDSP::InsertParallelMACs(Reduction &R) {
- auto CreateSMLADCall = [&](SmallVectorImpl<LoadInst*> &VecLd0,
- SmallVectorImpl<LoadInst*> &VecLd1,
- Value *Acc, bool Exchange,
- Instruction *InsertAfter) {
+ auto CreateSMLAD = [&](LoadInst* WideLd0, LoadInst *WideLd1,
+ Value *Acc, bool Exchange,
+ Instruction *InsertAfter) {
// Replace the reduction chain with an intrinsic call
- IntegerType *Ty = IntegerType::get(M->getContext(), 32);
- LoadInst *WideLd0 = WideLoads.count(VecLd0[0]) ?
- WideLoads[VecLd0[0]]->getLoad() : CreateWideLoad(VecLd0, Ty);
- LoadInst *WideLd1 = WideLoads.count(VecLd1[0]) ?
- WideLoads[VecLd1[0]]->getLoad() : CreateWideLoad(VecLd1, Ty);
Value* Args[] = { WideLd0, WideLd1, Acc };
Function *SMLAD = nullptr;
@@ -704,34 +641,95 @@ void ARMParallelDSP::InsertParallelMACs(Reduction &R) {
Intrinsic::getDeclaration(M, Intrinsic::arm_smlald);
IRBuilder<NoFolder> Builder(InsertAfter->getParent(),
- ++BasicBlock::iterator(InsertAfter));
+ BasicBlock::iterator(InsertAfter));
Instruction *Call = Builder.CreateCall(SMLAD, Args);
NumSMLAD++;
return Call;
};
- Instruction *InsertAfter = R.getRoot();
+ // Return the instruction after the dominated instruction.
+ auto GetInsertPoint = [this](Value *A, Value *B) {
+ assert((isa<Instruction>(A) || isa<Instruction>(B)) &&
+ "expected at least one instruction");
+
+ Value *V = nullptr;
+ if (!isa<Instruction>(A))
+ V = B;
+ else if (!isa<Instruction>(B))
+ V = A;
+ else
+ V = DT->dominates(cast<Instruction>(A), cast<Instruction>(B)) ? B : A;
+
+ return &*++BasicBlock::iterator(cast<Instruction>(V));
+ };
+
Value *Acc = R.getAccumulator();
- if (!Acc)
- Acc = ConstantInt::get(IntegerType::get(M->getContext(), 32), 0);
- LLVM_DEBUG(dbgs() << "Root: " << *InsertAfter << "\n"
- << "Acc: " << *Acc << "\n");
+ // For any muls that were discovered but not paired, accumulate their values
+ // as before.
+ IRBuilder<NoFolder> Builder(R.getRoot()->getParent());
+ MulCandList &MulCands = R.getMuls();
+ for (auto &MulCand : MulCands) {
+ if (MulCand->Paired)
+ continue;
+
+ Instruction *Mul = cast<Instruction>(MulCand->Root);
+ LLVM_DEBUG(dbgs() << "Accumulating unpaired mul: " << *Mul << "\n");
+
+ if (R.getType() != Mul->getType()) {
+ assert(R.is64Bit() && "expected 64-bit result");
+ Builder.SetInsertPoint(&*++BasicBlock::iterator(Mul));
+ Mul = cast<Instruction>(Builder.CreateSExt(Mul, R.getRoot()->getType()));
+ }
+
+ if (!Acc) {
+ Acc = Mul;
+ continue;
+ }
+
+ // If Acc is the original incoming value to the reduction, it could be a
+ // phi. But the phi will dominate Mul, meaning that Mul will be the
+ // insertion point.
+ Builder.SetInsertPoint(GetInsertPoint(Mul, Acc));
+ Acc = Builder.CreateAdd(Mul, Acc);
+ }
+
+ if (!Acc) {
+ Acc = R.is64Bit() ?
+ ConstantInt::get(IntegerType::get(M->getContext(), 64), 0) :
+ ConstantInt::get(IntegerType::get(M->getContext(), 32), 0);
+ } else if (Acc->getType() != R.getType()) {
+ Builder.SetInsertPoint(R.getRoot());
+ Acc = Builder.CreateSExt(Acc, R.getType());
+ }
+
+ // Roughly sort the mul pairs in their program order.
+ OrderedBasicBlock OrderedBB(R.getRoot()->getParent());
+ llvm::sort(R.getMulPairs(), [&OrderedBB](auto &PairA, auto &PairB) {
+ const Instruction *A = PairA.first->Root;
+ const Instruction *B = PairB.first->Root;
+ return OrderedBB.dominates(A, B);
+ });
+
+ IntegerType *Ty = IntegerType::get(M->getContext(), 32);
for (auto &Pair : R.getMulPairs()) {
- BinOpChain *PMul0 = Pair.first;
- BinOpChain *PMul1 = Pair.second;
- LLVM_DEBUG(dbgs() << "Muls:\n"
- << "- " << *PMul0->Root << "\n"
- << "- " << *PMul1->Root << "\n");
-
- Acc = CreateSMLADCall(PMul0->VecLd, PMul1->VecLd, Acc, PMul1->Exchange,
- InsertAfter);
- InsertAfter = cast<Instruction>(Acc);
+ MulCandidate *LHSMul = Pair.first;
+ MulCandidate *RHSMul = Pair.second;
+ LoadInst *BaseLHS = LHSMul->getBaseLoad();
+ LoadInst *BaseRHS = RHSMul->getBaseLoad();
+ LoadInst *WideLHS = WideLoads.count(BaseLHS) ?
+ WideLoads[BaseLHS]->getLoad() : CreateWideLoad(LHSMul->VecLd, Ty);
+ LoadInst *WideRHS = WideLoads.count(BaseRHS) ?
+ WideLoads[BaseRHS]->getLoad() : CreateWideLoad(RHSMul->VecLd, Ty);
+
+ Instruction *InsertAfter = GetInsertPoint(WideLHS, WideRHS);
+ InsertAfter = GetInsertPoint(InsertAfter, Acc);
+ Acc = CreateSMLAD(WideLHS, WideRHS, Acc, RHSMul->Exchange, InsertAfter);
}
R.UpdateRoot(cast<Instruction>(Acc));
}
-LoadInst* ARMParallelDSP::CreateWideLoad(SmallVectorImpl<LoadInst*> &Loads,
+LoadInst* ARMParallelDSP::CreateWideLoad(MemInstList &Loads,
IntegerType *LoadTy) {
assert(Loads.size() == 2 && "currently only support widening two loads");
@@ -758,8 +756,8 @@ LoadInst* ARMParallelDSP::CreateWideLoad(SmallVectorImpl<LoadInst*> &Loads,
return;
Source->moveBefore(Sink);
- for (auto &U : Source->uses())
- MoveBefore(Source, U.getUser());
+ for (auto &Op : Source->operands())
+ MoveBefore(Op, Source);
};
// Insert the load at the point of the original dominating load.
@@ -784,57 +782,30 @@ LoadInst* ARMParallelDSP::CreateWideLoad(SmallVectorImpl<LoadInst*> &Loads,
// Loads[0] needs trunc while Loads[1] needs a lshr and trunc.
// TODO: Support big-endian as well.
Value *Bottom = IRB.CreateTrunc(WideLoad, Base->getType());
- BaseSExt->setOperand(0, Bottom);
+ Value *NewBaseSExt = IRB.CreateSExt(Bottom, BaseSExt->getType());
+ BaseSExt->replaceAllUsesWith(NewBaseSExt);
IntegerType *OffsetTy = cast<IntegerType>(Offset->getType());
Value *ShiftVal = ConstantInt::get(LoadTy, OffsetTy->getBitWidth());
Value *Top = IRB.CreateLShr(WideLoad, ShiftVal);
Value *Trunc = IRB.CreateTrunc(Top, OffsetTy);
- OffsetSExt->setOperand(0, Trunc);
-
+ Value *NewOffsetSExt = IRB.CreateSExt(Trunc, OffsetSExt->getType());
+ OffsetSExt->replaceAllUsesWith(NewOffsetSExt);
+
+ LLVM_DEBUG(dbgs() << "From Base and Offset:\n"
+ << *Base << "\n" << *Offset << "\n"
+ << "Created Wide Load:\n"
+ << *WideLoad << "\n"
+ << *Bottom << "\n"
+ << *NewBaseSExt << "\n"
+ << *Top << "\n"
+ << *Trunc << "\n"
+ << *NewOffsetSExt << "\n");
WideLoads.emplace(std::make_pair(Base,
- make_unique<WidenedLoad>(Loads, WideLoad)));
+ std::make_unique<WidenedLoad>(Loads, WideLoad)));
return WideLoad;
}
-// Compare the value lists in Other to this chain.
-bool BinOpChain::AreSymmetrical(BinOpChain *Other) {
- // Element-by-element comparison of Value lists returning true if they are
- // instructions with the same opcode or constants with the same value.
- auto CompareValueList = [](const ValueList &VL0,
- const ValueList &VL1) {
- if (VL0.size() != VL1.size()) {
- LLVM_DEBUG(dbgs() << "Muls are mismatching operand list lengths: "
- << VL0.size() << " != " << VL1.size() << "\n");
- return false;
- }
-
- const unsigned Pairs = VL0.size();
-
- for (unsigned i = 0; i < Pairs; ++i) {
- const Value *V0 = VL0[i];
- const Value *V1 = VL1[i];
- const auto *Inst0 = dyn_cast<Instruction>(V0);
- const auto *Inst1 = dyn_cast<Instruction>(V1);
-
- if (!Inst0 || !Inst1)
- return false;
-
- if (Inst0->isSameOperationAs(Inst1))
- continue;
-
- const APInt *C0, *C1;
- if (!(match(V0, m_APInt(C0)) && match(V1, m_APInt(C1)) && C0 == C1))
- return false;
- }
-
- return true;
- };
-
- return CompareValueList(LHS, Other->LHS) &&
- CompareValueList(RHS, Other->RHS);
-}
-
Pass *llvm::createARMParallelDSPPass() {
return new ARMParallelDSP();
}
@@ -842,6 +813,6 @@ Pass *llvm::createARMParallelDSPPass() {
char ARMParallelDSP::ID = 0;
INITIALIZE_PASS_BEGIN(ARMParallelDSP, "arm-parallel-dsp",
- "Transform loops to use DSP intrinsics", false, false)
+ "Transform functions to use DSP intrinsics", false, false)
INITIALIZE_PASS_END(ARMParallelDSP, "arm-parallel-dsp",
- "Transform loops to use DSP intrinsics", false, false)
+ "Transform functions to use DSP intrinsics", false, false)