summaryrefslogtreecommitdiff
path: root/lib/CodeGen/RegUsageInfoPropagate.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'lib/CodeGen/RegUsageInfoPropagate.cpp')
-rw-r--r--lib/CodeGen/RegUsageInfoPropagate.cpp35
1 files changed, 26 insertions, 9 deletions
diff --git a/lib/CodeGen/RegUsageInfoPropagate.cpp b/lib/CodeGen/RegUsageInfoPropagate.cpp
index 5cc35bfeca63..5b12d00e126f 100644
--- a/lib/CodeGen/RegUsageInfoPropagate.cpp
+++ b/lib/CodeGen/RegUsageInfoPropagate.cpp
@@ -21,6 +21,7 @@
#include "llvm/CodeGen/MachineBasicBlock.h"
#include "llvm/CodeGen/MachineFunctionPass.h"
+#include "llvm/CodeGen/MachineFrameInfo.h"
#include "llvm/CodeGen/MachineInstr.h"
#include "llvm/CodeGen/MachineRegisterInfo.h"
#include "llvm/CodeGen/Passes.h"
@@ -87,14 +88,31 @@ void RegUsageInfoPropagationPass::getAnalysisUsage(AnalysisUsage &AU) const {
MachineFunctionPass::getAnalysisUsage(AU);
}
+// Assumes call instructions have a single reference to a function.
+static const Function *findCalledFunction(const Module &M, MachineInstr &MI) {
+ for (MachineOperand &MO : MI.operands()) {
+ if (MO.isGlobal())
+ return dyn_cast<Function>(MO.getGlobal());
+
+ if (MO.isSymbol())
+ return M.getFunction(MO.getSymbolName());
+ }
+
+ return nullptr;
+}
+
bool RegUsageInfoPropagationPass::runOnMachineFunction(MachineFunction &MF) {
- const Module *M = MF.getFunction()->getParent();
+ const Module *M = MF.getFunction().getParent();
PhysicalRegisterUsageInfo *PRUI = &getAnalysis<PhysicalRegisterUsageInfo>();
DEBUG(dbgs() << " ++++++++++++++++++++ " << getPassName()
<< " ++++++++++++++++++++ \n");
DEBUG(dbgs() << "MachineFunction : " << MF.getName() << "\n");
+ const MachineFrameInfo &MFI = MF.getFrameInfo();
+ if (!MFI.hasCalls() && !MFI.hasTailCall())
+ return false;
+
bool Changed = false;
for (MachineBasicBlock &MBB : MF) {
@@ -113,15 +131,14 @@ bool RegUsageInfoPropagationPass::runOnMachineFunction(MachineFunction &MF) {
Changed = true;
};
- MachineOperand &Operand = MI.getOperand(0);
- if (Operand.isGlobal())
- UpdateRegMask(cast<Function>(Operand.getGlobal()));
- else if (Operand.isSymbol())
- UpdateRegMask(M->getFunction(Operand.getSymbolName()));
+ if (const Function *F = findCalledFunction(*M, MI)) {
+ UpdateRegMask(F);
+ } else {
+ DEBUG(dbgs() << "Failed to find call target function\n");
+ }
- DEBUG(dbgs()
- << "Call Instruction After Register Usage Info Propagation : \n");
- DEBUG(dbgs() << MI << "\n");
+ DEBUG(dbgs() << "Call Instruction After Register Usage Info Propagation : "
+ << MI << '\n');
}
}