diff options
Diffstat (limited to 'lib/Target/BPF/BPFISelDAGToDAG.cpp')
| -rw-r--r-- | lib/Target/BPF/BPFISelDAGToDAG.cpp | 312 |
1 files changed, 235 insertions, 77 deletions
diff --git a/lib/Target/BPF/BPFISelDAGToDAG.cpp b/lib/Target/BPF/BPFISelDAGToDAG.cpp index c6ddd6bdad5e..f48429ee57b0 100644 --- a/lib/Target/BPF/BPFISelDAGToDAG.cpp +++ b/lib/Target/BPF/BPFISelDAGToDAG.cpp @@ -16,6 +16,7 @@ #include "BPFRegisterInfo.h" #include "BPFSubtarget.h" #include "BPFTargetMachine.h" +#include "llvm/CodeGen/FunctionLoweringInfo.h" #include "llvm/CodeGen/MachineConstantPool.h" #include "llvm/CodeGen/MachineFrameInfo.h" #include "llvm/CodeGen/MachineFunction.h" @@ -57,6 +58,11 @@ private: bool SelectAddr(SDValue Addr, SDValue &Base, SDValue &Offset); bool SelectFIAddr(SDValue Addr, SDValue &Base, SDValue &Offset); + // Node preprocessing cases + void PreprocessLoad(SDNode *Node, SelectionDAG::allnodes_iterator I); + void PreprocessCopyToReg(SDNode *Node); + void PreprocessTrunc(SDNode *Node, SelectionDAG::allnodes_iterator I); + // Find constants from a constant structure typedef std::vector<unsigned char> val_vec_type; bool fillGenericConstant(const DataLayout &DL, const Constant *CV, @@ -69,9 +75,12 @@ private: val_vec_type &Vals, int Offset); bool getConstantFieldValue(const GlobalAddressSDNode *Node, uint64_t Offset, uint64_t Size, unsigned char *ByteSeq); + bool checkLoadDef(unsigned DefReg, unsigned match_load_op); // Mapping from ConstantStruct global value to corresponding byte-list values std::map<const void *, val_vec_type> cs_vals_; + // Mapping from vreg to load memory opcode + std::map<unsigned, unsigned> load_to_vreg_; }; } // namespace @@ -203,89 +212,110 @@ void BPFDAGToDAGISel::Select(SDNode *Node) { SelectCode(Node); } +void BPFDAGToDAGISel::PreprocessLoad(SDNode *Node, + SelectionDAG::allnodes_iterator I) { + union { + uint8_t c[8]; + uint16_t s; + uint32_t i; + uint64_t d; + } new_val; // hold up the constant values replacing loads. + bool to_replace = false; + SDLoc DL(Node); + const LoadSDNode *LD = cast<LoadSDNode>(Node); + uint64_t size = LD->getMemOperand()->getSize(); + + if (!size || size > 8 || (size & (size - 1))) + return; + + SDNode *LDAddrNode = LD->getOperand(1).getNode(); + // Match LDAddr against either global_addr or (global_addr + offset) + unsigned opcode = LDAddrNode->getOpcode(); + if (opcode == ISD::ADD) { + SDValue OP1 = LDAddrNode->getOperand(0); + SDValue OP2 = LDAddrNode->getOperand(1); + + // We want to find the pattern global_addr + offset + SDNode *OP1N = OP1.getNode(); + if (OP1N->getOpcode() <= ISD::BUILTIN_OP_END || OP1N->getNumOperands() == 0) + return; + + DEBUG(dbgs() << "Check candidate load: "; LD->dump(); dbgs() << '\n'); + + const GlobalAddressSDNode *GADN = + dyn_cast<GlobalAddressSDNode>(OP1N->getOperand(0).getNode()); + const ConstantSDNode *CDN = dyn_cast<ConstantSDNode>(OP2.getNode()); + if (GADN && CDN) + to_replace = + getConstantFieldValue(GADN, CDN->getZExtValue(), size, new_val.c); + } else if (LDAddrNode->getOpcode() > ISD::BUILTIN_OP_END && + LDAddrNode->getNumOperands() > 0) { + DEBUG(dbgs() << "Check candidate load: "; LD->dump(); dbgs() << '\n'); + + SDValue OP1 = LDAddrNode->getOperand(0); + if (const GlobalAddressSDNode *GADN = + dyn_cast<GlobalAddressSDNode>(OP1.getNode())) + to_replace = getConstantFieldValue(GADN, 0, size, new_val.c); + } + + if (!to_replace) + return; + + // replacing the old with a new value + uint64_t val; + if (size == 1) + val = new_val.c[0]; + else if (size == 2) + val = new_val.s; + else if (size == 4) + val = new_val.i; + else { + val = new_val.d; + } + + DEBUG(dbgs() << "Replacing load of size " << size << " with constant " << val + << '\n'); + SDValue NVal = CurDAG->getConstant(val, DL, MVT::i64); + + // After replacement, the current node is dead, we need to + // go backward one step to make iterator still work + I--; + SDValue From[] = {SDValue(Node, 0), SDValue(Node, 1)}; + SDValue To[] = {NVal, NVal}; + CurDAG->ReplaceAllUsesOfValuesWith(From, To, 2); + I++; + // It is safe to delete node now + CurDAG->DeleteNode(Node); +} + void BPFDAGToDAGISel::PreprocessISelDAG() { - // Iterate through all nodes, only interested in loads from ConstantStruct - // ConstantArray should have converted by IR->DAG processing + // Iterate through all nodes, interested in the following cases: + // + // . loads from ConstantStruct or ConstantArray of constructs + // which can be turns into constant itself, with this we can + // avoid reading from read-only section at runtime. + // + // . reg truncating is often the result of 8/16/32bit->64bit or + // 8/16bit->32bit conversion. If the reg value is loaded with + // masked byte width, the AND operation can be removed since + // BPF LOAD already has zero extension. + // + // This also solved a correctness issue. + // In BPF socket-related program, e.g., __sk_buff->{data, data_end} + // are 32-bit registers, but later on, kernel verifier will rewrite + // it with 64-bit value. Therefore, truncating the value after the + // load will result in incorrect code. for (SelectionDAG::allnodes_iterator I = CurDAG->allnodes_begin(), E = CurDAG->allnodes_end(); I != E;) { SDNode *Node = &*I++; unsigned Opcode = Node->getOpcode(); - if (Opcode != ISD::LOAD) - continue; - - union { - uint8_t c[8]; - uint16_t s; - uint32_t i; - uint64_t d; - } new_val; // hold up the constant values replacing loads. - bool to_replace = false; - SDLoc DL(Node); - const LoadSDNode *LD = cast<LoadSDNode>(Node); - uint64_t size = LD->getMemOperand()->getSize(); - if (!size || size > 8 || (size & (size - 1))) - continue; - - SDNode *LDAddrNode = LD->getOperand(1).getNode(); - // Match LDAddr against either global_addr or (global_addr + offset) - unsigned opcode = LDAddrNode->getOpcode(); - if (opcode == ISD::ADD) { - SDValue OP1 = LDAddrNode->getOperand(0); - SDValue OP2 = LDAddrNode->getOperand(1); - - // We want to find the pattern global_addr + offset - SDNode *OP1N = OP1.getNode(); - if (OP1N->getOpcode() <= ISD::BUILTIN_OP_END || - OP1N->getNumOperands() == 0) - continue; - - DEBUG(dbgs() << "Check candidate load: "; LD->dump(); dbgs() << '\n'); - - const GlobalAddressSDNode *GADN = - dyn_cast<GlobalAddressSDNode>(OP1N->getOperand(0).getNode()); - const ConstantSDNode *CDN = dyn_cast<ConstantSDNode>(OP2.getNode()); - if (GADN && CDN) - to_replace = - getConstantFieldValue(GADN, CDN->getZExtValue(), size, new_val.c); - } else if (LDAddrNode->getOpcode() > ISD::BUILTIN_OP_END && - LDAddrNode->getNumOperands() > 0) { - DEBUG(dbgs() << "Check candidate load: "; LD->dump(); dbgs() << '\n'); - - SDValue OP1 = LDAddrNode->getOperand(0); - if (const GlobalAddressSDNode *GADN = - dyn_cast<GlobalAddressSDNode>(OP1.getNode())) - to_replace = getConstantFieldValue(GADN, 0, size, new_val.c); - } - - if (!to_replace) - continue; - - // replacing the old with a new value - uint64_t val; - if (size == 1) - val = new_val.c[0]; - else if (size == 2) - val = new_val.s; - else if (size == 4) - val = new_val.i; - else { - val = new_val.d; - } - - DEBUG(dbgs() << "Replacing load of size " << size << " with constant " - << val << '\n'); - SDValue NVal = CurDAG->getConstant(val, DL, MVT::i64); - - // After replacement, the current node is dead, we need to - // go backward one step to make iterator still work - I--; - SDValue From[] = {SDValue(Node, 0), SDValue(Node, 1)}; - SDValue To[] = {NVal, NVal}; - CurDAG->ReplaceAllUsesOfValuesWith(From, To, 2); - I++; - // It is safe to delete node now - CurDAG->DeleteNode(Node); + if (Opcode == ISD::LOAD) + PreprocessLoad(Node, I); + else if (Opcode == ISD::CopyToReg) + PreprocessCopyToReg(Node); + else if (Opcode == ISD::AND) + PreprocessTrunc(Node, I); } } @@ -415,6 +445,134 @@ bool BPFDAGToDAGISel::fillConstantStruct(const DataLayout &DL, return true; } +void BPFDAGToDAGISel::PreprocessCopyToReg(SDNode *Node) { + const RegisterSDNode *RegN = dyn_cast<RegisterSDNode>(Node->getOperand(1)); + if (!RegN || !TargetRegisterInfo::isVirtualRegister(RegN->getReg())) + return; + + const LoadSDNode *LD = dyn_cast<LoadSDNode>(Node->getOperand(2)); + if (!LD) + return; + + // Assign a load value to a virtual register. record its load width + unsigned mem_load_op = 0; + switch (LD->getMemOperand()->getSize()) { + default: + return; + case 4: + mem_load_op = BPF::LDW; + break; + case 2: + mem_load_op = BPF::LDH; + break; + case 1: + mem_load_op = BPF::LDB; + break; + } + + DEBUG(dbgs() << "Find Load Value to VReg " + << TargetRegisterInfo::virtReg2Index(RegN->getReg()) << '\n'); + load_to_vreg_[RegN->getReg()] = mem_load_op; +} + +void BPFDAGToDAGISel::PreprocessTrunc(SDNode *Node, + SelectionDAG::allnodes_iterator I) { + ConstantSDNode *MaskN = dyn_cast<ConstantSDNode>(Node->getOperand(1)); + if (!MaskN) + return; + + unsigned match_load_op = 0; + switch (MaskN->getZExtValue()) { + default: + return; + case 0xFFFFFFFF: + match_load_op = BPF::LDW; + break; + case 0xFFFF: + match_load_op = BPF::LDH; + break; + case 0xFF: + match_load_op = BPF::LDB; + break; + } + + // The Reg operand should be a virtual register, which is defined + // outside the current basic block. DAG combiner has done a pretty + // good job in removing truncating inside a single basic block. + SDValue BaseV = Node->getOperand(0); + if (BaseV.getOpcode() != ISD::CopyFromReg) + return; + + const RegisterSDNode *RegN = + dyn_cast<RegisterSDNode>(BaseV.getNode()->getOperand(1)); + if (!RegN || !TargetRegisterInfo::isVirtualRegister(RegN->getReg())) + return; + unsigned AndOpReg = RegN->getReg(); + DEBUG(dbgs() << "Examine %vreg" << TargetRegisterInfo::virtReg2Index(AndOpReg) + << '\n'); + + // Examine the PHI insns in the MachineBasicBlock to found out the + // definitions of this virtual register. At this stage (DAG2DAG + // transformation), only PHI machine insns are available in the machine basic + // block. + MachineBasicBlock *MBB = FuncInfo->MBB; + MachineInstr *MII = nullptr; + for (auto &MI : *MBB) { + for (unsigned i = 0; i < MI.getNumOperands(); ++i) { + const MachineOperand &MOP = MI.getOperand(i); + if (!MOP.isReg() || !MOP.isDef()) + continue; + unsigned Reg = MOP.getReg(); + if (TargetRegisterInfo::isVirtualRegister(Reg) && Reg == AndOpReg) { + MII = &MI; + break; + } + } + } + + if (MII == nullptr) { + // No phi definition in this block. + if (!checkLoadDef(AndOpReg, match_load_op)) + return; + } else { + // The PHI node looks like: + // %vreg2<def> = PHI %vreg0, <BB#1>, %vreg1, <BB#3> + // Trace each incoming definition, e.g., (%vreg0, BB#1) and (%vreg1, BB#3) + // The AND operation can be removed if both %vreg0 in BB#1 and %vreg1 in + // BB#3 are defined with with a load matching the MaskN. + DEBUG(dbgs() << "Check PHI Insn: "; MII->dump(); dbgs() << '\n'); + unsigned PrevReg = -1; + for (unsigned i = 0; i < MII->getNumOperands(); ++i) { + const MachineOperand &MOP = MII->getOperand(i); + if (MOP.isReg()) { + if (MOP.isDef()) + continue; + PrevReg = MOP.getReg(); + if (!TargetRegisterInfo::isVirtualRegister(PrevReg)) + return; + if (!checkLoadDef(PrevReg, match_load_op)) + return; + } + } + } + + DEBUG(dbgs() << "Remove the redundant AND operation in: "; Node->dump(); + dbgs() << '\n'); + + I--; + CurDAG->ReplaceAllUsesWith(SDValue(Node, 0), BaseV); + I++; + CurDAG->DeleteNode(Node); +} + +bool BPFDAGToDAGISel::checkLoadDef(unsigned DefReg, unsigned match_load_op) { + auto it = load_to_vreg_.find(DefReg); + if (it == load_to_vreg_.end()) + return false; // The definition of register is not exported yet. + + return it->second == match_load_op; +} + FunctionPass *llvm::createBPFISelDag(BPFTargetMachine &TM) { return new BPFDAGToDAGISel(TM); } |
