Diffstat (limited to 'llvm/lib/Target/AArch64/AArch64InstrInfo.cpp')
-rw-r--r-- | llvm/lib/Target/AArch64/AArch64InstrInfo.cpp | 1037
1 file changed, 962 insertions, 75 deletions
diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp b/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
index 5c35e5bcdd30..54f3f7c10132 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
@@ -30,6 +30,7 @@
 #include "llvm/CodeGen/StackMaps.h"
 #include "llvm/CodeGen/TargetRegisterInfo.h"
 #include "llvm/CodeGen/TargetSubtargetInfo.h"
+#include "llvm/IR/DebugInfoMetadata.h"
 #include "llvm/IR/DebugLoc.h"
 #include "llvm/IR/GlobalValue.h"
 #include "llvm/MC/MCAsmInfo.h"
@@ -1981,6 +1982,9 @@ bool AArch64InstrInfo::getMemOperandWithOffset(const MachineInstr &LdSt,
                                                const MachineOperand *&BaseOp,
                                                int64_t &Offset,
                                                const TargetRegisterInfo *TRI) const {
+  if (!LdSt.mayLoadOrStore())
+    return false;
+
   unsigned Width;
   return getMemOperandWithOffsetWidth(LdSt, BaseOp, Offset, Width, TRI);
 }
@@ -2025,9 +2029,8 @@ bool AArch64InstrInfo::getMemOperandWithOffsetWidth(
     Offset = LdSt.getOperand(3).getImm() * Scale;
   }
 
-  assert((BaseOp->isReg() || BaseOp->isFI()) &&
-         "getMemOperandWithOffset only supports base "
-         "operands of type register or frame index.");
+  if (!BaseOp->isReg() && !BaseOp->isFI())
+    return false;
 
   return true;
 }
@@ -2185,12 +2188,19 @@ bool AArch64InstrInfo::getMemOpInfo(unsigned Opcode, unsigned &Scale,
     MaxOffset = 4095;
     break;
   case AArch64::ADDG:
-  case AArch64::TAGPstack:
     Scale = 16;
     Width = 0;
     MinOffset = 0;
     MaxOffset = 63;
     break;
+  case AArch64::TAGPstack:
+    Scale = 16;
+    Width = 0;
+    // TAGP with a negative offset turns into SUBP, which has a maximum offset
+    // of 63 (not 64!).
+    MinOffset = -63;
+    MaxOffset = 63;
+    break;
   case AArch64::LDG:
   case AArch64::STGOffset:
   case AArch64::STZGOffset:
@@ -2227,54 +2237,82 @@ bool AArch64InstrInfo::getMemOpInfo(unsigned Opcode, unsigned &Scale,
   return true;
 }
 
-static unsigned getOffsetStride(unsigned Opc) {
+// Scaling factor for unscaled load or store.
+int AArch64InstrInfo::getMemScale(unsigned Opc) {
   switch (Opc) {
   default:
-    return 0;
-  case AArch64::LDURQi:
-  case AArch64::STURQi:
-    return 16;
-  case AArch64::LDURXi:
-  case AArch64::LDURDi:
-  case AArch64::STURXi:
-  case AArch64::STURDi:
-    return 8;
-  case AArch64::LDURWi:
+    llvm_unreachable("Opcode has unknown scale!");
+  case AArch64::LDRBBui:
+  case AArch64::LDURBBi:
+  case AArch64::LDRSBWui:
+  case AArch64::LDURSBWi:
+  case AArch64::STRBBui:
+  case AArch64::STURBBi:
+    return 1;
+  case AArch64::LDRHHui:
+  case AArch64::LDURHHi:
+  case AArch64::LDRSHWui:
+  case AArch64::LDURSHWi:
+  case AArch64::STRHHui:
+  case AArch64::STURHHi:
+    return 2;
+  case AArch64::LDRSui:
   case AArch64::LDURSi:
+  case AArch64::LDRSWui:
   case AArch64::LDURSWi:
-  case AArch64::STURWi:
+  case AArch64::LDRWui:
+  case AArch64::LDURWi:
+  case AArch64::STRSui:
   case AArch64::STURSi:
+  case AArch64::STRWui:
+  case AArch64::STURWi:
+  case AArch64::LDPSi:
+  case AArch64::LDPSWi:
+  case AArch64::LDPWi:
+  case AArch64::STPSi:
+  case AArch64::STPWi:
     return 4;
+  case AArch64::LDRDui:
+  case AArch64::LDURDi:
+  case AArch64::LDRXui:
+  case AArch64::LDURXi:
+  case AArch64::STRDui:
+  case AArch64::STURDi:
+  case AArch64::STRXui:
+  case AArch64::STURXi:
+  case AArch64::LDPDi:
+  case AArch64::LDPXi:
+  case AArch64::STPDi:
+  case AArch64::STPXi:
+    return 8;
+  case AArch64::LDRQui:
+  case AArch64::LDURQi:
+  case AArch64::STRQui:
+  case AArch64::STURQi:
+  case AArch64::LDPQi:
+  case AArch64::STPQi:
+  case AArch64::STGOffset:
+  case AArch64::STZGOffset:
+  case AArch64::ST2GOffset:
+  case AArch64::STZ2GOffset:
+  case AArch64::STGPi:
+    return 16;
   }
 }
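A minimal sketch of how the getMemScale table above is consumed by callers (assumed values, not code from this patch): byte offsets are only convertible to scaled "element" offsets when they divide evenly by the element size.

    // STURXi transfers 8 bytes, so only byte offsets divisible by 8 scale.
    int64_t ByteOffset = 24;
    int Scale = AArch64InstrInfo::getMemScale(AArch64::STURXi); // == 8
    if (ByteOffset % Scale == 0) {
      int64_t ElementOffset = ByteOffset / Scale; // == 3
    }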
 
 // Scale the unscaled offsets. Returns false if the unscaled offset can't be
 // scaled.
 static bool scaleOffset(unsigned Opc, int64_t &Offset) {
-  unsigned OffsetStride = getOffsetStride(Opc);
-  if (OffsetStride == 0)
-    return false;
+  int Scale = AArch64InstrInfo::getMemScale(Opc);
+
   // If the byte-offset isn't a multiple of the stride, we can't scale this
   // offset.
-  if (Offset % OffsetStride != 0)
+  if (Offset % Scale != 0)
     return false;
 
   // Convert the byte-offset used by unscaled into an "element" offset used
   // by the scaled pair load/store instructions.
-  Offset /= OffsetStride;
-  return true;
-}
-
-// Unscale the scaled offsets. Returns false if the scaled offset can't be
-// unscaled.
-static bool unscaleOffset(unsigned Opc, int64_t &Offset) {
-  unsigned OffsetStride = getOffsetStride(Opc);
-  if (OffsetStride == 0)
-    return false;
-
-  // Convert the "element" offset used by scaled pair load/store instructions
-  // into the byte-offset used by unscaled.
-  Offset *= OffsetStride;
+  Offset /= Scale;
   return true;
 }
 
@@ -2305,15 +2343,17 @@ static bool shouldClusterFI(const MachineFrameInfo &MFI, int FI1,
   int64_t ObjectOffset1 = MFI.getObjectOffset(FI1);
   int64_t ObjectOffset2 = MFI.getObjectOffset(FI2);
   assert(ObjectOffset1 <= ObjectOffset2 && "Object offsets are not ordered.");
-  // Get the byte-offset from the object offset.
-  if (!unscaleOffset(Opcode1, Offset1) || !unscaleOffset(Opcode2, Offset2))
+  // Convert to scaled object offsets.
+  int Scale1 = AArch64InstrInfo::getMemScale(Opcode1);
+  if (ObjectOffset1 % Scale1 != 0)
     return false;
+  ObjectOffset1 /= Scale1;
+  int Scale2 = AArch64InstrInfo::getMemScale(Opcode2);
+  if (ObjectOffset2 % Scale2 != 0)
+    return false;
+  ObjectOffset2 /= Scale2;
   ObjectOffset1 += Offset1;
   ObjectOffset2 += Offset2;
-  // Get the "element" index in the object.
-  if (!scaleOffset(Opcode1, ObjectOffset1) ||
-      !scaleOffset(Opcode2, ObjectOffset2))
-    return false;
   return ObjectOffset1 + 1 == ObjectOffset2;
 }
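The rewritten shouldClusterFI scales the frame-object offsets directly instead of unscaling the instruction offsets. Worked numbers (illustrative only):

    // Two 8-byte stores to frame objects at byte offsets -16 and -8:
    int64_t ObjectOffset1 = -16, ObjectOffset2 = -8;
    int Scale1 = 8, Scale2 = 8;      // e.g. both accesses are STURXi
    ObjectOffset1 /= Scale1;         // -2
    ObjectOffset2 /= Scale2;         // -1
    // With zero instruction offsets, -2 + 1 == -1, so the pair may cluster.
    bool MayCluster = (ObjectOffset1 + 1 == ObjectOffset2); // true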
@@ -2373,7 +2413,7 @@ bool AArch64InstrInfo::shouldClusterMemOps(const MachineOperand &BaseOp1,
   // The caller should already have ordered First/SecondLdSt by offset.
   // Note: except for non-equal frame index bases
   if (BaseOp1.isFI()) {
-    assert((!BaseOp1.isIdenticalTo(BaseOp2) || Offset1 >= Offset2) &&
+    assert((!BaseOp1.isIdenticalTo(BaseOp2) || Offset1 <= Offset2) &&
            "Caller should have ordered offsets.");
 
     const MachineFrameInfo &MFI =
@@ -2382,8 +2422,7 @@ bool AArch64InstrInfo::shouldClusterMemOps(const MachineOperand &BaseOp1,
                            BaseOp2.getIndex(), Offset2, SecondOpc);
   }
 
-  assert((!BaseOp1.isIdenticalTo(BaseOp2) || Offset1 <= Offset2) &&
-         "Caller should have ordered offsets.");
+  assert(Offset1 <= Offset2 && "Caller should have ordered offsets.");
 
   return Offset1 + 1 == Offset2;
 }
@@ -2409,8 +2448,8 @@ static bool forwardCopyWillClobberTuple(unsigned DestReg, unsigned SrcReg,
 
 void AArch64InstrInfo::copyPhysRegTuple(MachineBasicBlock &MBB,
                                         MachineBasicBlock::iterator I,
-                                        const DebugLoc &DL, unsigned DestReg,
-                                        unsigned SrcReg, bool KillSrc,
+                                        const DebugLoc &DL, MCRegister DestReg,
+                                        MCRegister SrcReg, bool KillSrc,
                                         unsigned Opcode,
                                         ArrayRef<unsigned> Indices) const {
   assert(Subtarget.hasNEON() && "Unexpected register copy without NEON");
@@ -2461,8 +2500,8 @@ void AArch64InstrInfo::copyGPRRegTuple(MachineBasicBlock &MBB,
 
 void AArch64InstrInfo::copyPhysReg(MachineBasicBlock &MBB,
                                    MachineBasicBlock::iterator I,
-                                   const DebugLoc &DL, unsigned DestReg,
-                                   unsigned SrcReg, bool KillSrc) const {
+                                   const DebugLoc &DL, MCRegister DestReg,
+                                   MCRegister SrcReg, bool KillSrc) const {
   if (AArch64::GPR32spRegClass.contains(DestReg) &&
       (AArch64::GPR32spRegClass.contains(SrcReg) || SrcReg == AArch64::WZR)) {
     const TargetRegisterInfo *TRI = &getRegisterInfo();
@@ -2471,10 +2510,10 @@ void AArch64InstrInfo::copyPhysReg(MachineBasicBlock &MBB,
 
     // If either operand is WSP, expand to ADD #0.
     if (Subtarget.hasZeroCycleRegMove()) {
      // Cyclone recognizes "ADD Xd, Xn, #0" as a zero-cycle register move.
-      unsigned DestRegX = TRI->getMatchingSuperReg(DestReg, AArch64::sub_32,
-                                                   &AArch64::GPR64spRegClass);
-      unsigned SrcRegX = TRI->getMatchingSuperReg(SrcReg, AArch64::sub_32,
-                                                  &AArch64::GPR64spRegClass);
+      MCRegister DestRegX = TRI->getMatchingSuperReg(
+          DestReg, AArch64::sub_32, &AArch64::GPR64spRegClass);
+      MCRegister SrcRegX = TRI->getMatchingSuperReg(
+          SrcReg, AArch64::sub_32, &AArch64::GPR64spRegClass);
       // This instruction is reading and writing X registers. This may upset
       // the register scavenger and machine verifier, so we need to indicate
       // that we are reading an undefined value from SrcRegX, but a proper
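For context, once the X super-registers have been computed the surrounding copyPhysReg code emits the widened move roughly as follows (a sketch assuming DestRegX/SrcRegX from the hunk above; the actual emission sits just past this hunk's context window):

    BuildMI(MBB, I, DL, get(AArch64::ADDXri), DestRegX)
        .addReg(SrcRegX, RegState::Undef)
        .addImm(0)
        .addImm(AArch64_AM::getShifterImm(AArch64_AM::LSL, 0))
        .addReg(SrcReg, RegState::Implicit | getKillRegState(KillSrc));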
@@ -2497,10 +2536,10 @@ void AArch64InstrInfo::copyPhysReg(MachineBasicBlock &MBB,
   } else {
     if (Subtarget.hasZeroCycleRegMove()) {
       // Cyclone recognizes "ORR Xd, XZR, Xm" as a zero-cycle register move.
-      unsigned DestRegX = TRI->getMatchingSuperReg(DestReg, AArch64::sub_32,
-                                                   &AArch64::GPR64spRegClass);
-      unsigned SrcRegX = TRI->getMatchingSuperReg(SrcReg, AArch64::sub_32,
-                                                  &AArch64::GPR64spRegClass);
+      MCRegister DestRegX = TRI->getMatchingSuperReg(
+          DestReg, AArch64::sub_32, &AArch64::GPR64spRegClass);
+      MCRegister SrcRegX = TRI->getMatchingSuperReg(
+          SrcReg, AArch64::sub_32, &AArch64::GPR64spRegClass);
       // This instruction is reading and writing X registers. This may upset
       // the register scavenger and machine verifier, so we need to indicate
       // that we are reading an undefined value from SrcRegX, but a proper
@@ -2897,7 +2936,18 @@ void AArch64InstrInfo::storeRegToStackSlot(
     }
     break;
   }
+  unsigned StackID = TargetStackID::Default;
+  if (AArch64::PPRRegClass.hasSubClassEq(RC)) {
+    assert(Subtarget.hasSVE() && "Unexpected register store without SVE");
+    Opc = AArch64::STR_PXI;
+    StackID = TargetStackID::SVEVector;
+  } else if (AArch64::ZPRRegClass.hasSubClassEq(RC)) {
+    assert(Subtarget.hasSVE() && "Unexpected register store without SVE");
+    Opc = AArch64::STR_ZXI;
+    StackID = TargetStackID::SVEVector;
+  }
   assert(Opc && "Unknown register class");
+  MFI.setStackID(FI, StackID);
 
   const MachineInstrBuilder MI = BuildMI(MBB, MBBI, DebugLoc(), get(Opc))
                                      .addReg(SrcReg, getKillRegState(isKill))
@@ -3028,7 +3078,19 @@ void AArch64InstrInfo::loadRegFromStackSlot(
     }
     break;
   }
+
+  unsigned StackID = TargetStackID::Default;
+  if (AArch64::PPRRegClass.hasSubClassEq(RC)) {
+    assert(Subtarget.hasSVE() && "Unexpected register load without SVE");
+    Opc = AArch64::LDR_PXI;
+    StackID = TargetStackID::SVEVector;
+  } else if (AArch64::ZPRRegClass.hasSubClassEq(RC)) {
+    assert(Subtarget.hasSVE() && "Unexpected register load without SVE");
+    Opc = AArch64::LDR_ZXI;
+    StackID = TargetStackID::SVEVector;
+  }
   assert(Opc && "Unknown register class");
+  MFI.setStackID(FI, StackID);
 
   const MachineInstrBuilder MI = BuildMI(MBB, MBBI, DebugLoc(), get(Opc))
                                      .addReg(DestReg, getDefRegState(true))
@@ -3085,7 +3147,7 @@ static void emitFrameOffsetAdj(MachineBasicBlock &MBB,
   const unsigned MaxEncodableValue = MaxEncoding << ShiftSize;
   do {
-    unsigned ThisVal = std::min<unsigned>(Offset, MaxEncodableValue);
+    uint64_t ThisVal = std::min<uint64_t>(Offset, MaxEncodableValue);
     unsigned LocalShiftSize = 0;
     if (ThisVal > MaxEncoding) {
       ThisVal = ThisVal >> ShiftSize;
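Why the widening in emitFrameOffsetAdj matters, with illustrative values: a large (e.g. scalable-vector-sized) offset would be truncated if ThisVal were still computed in 32 bits.

    uint64_t Offset = (1ULL << 32) + 16; // assumed large adjustment
    const unsigned MaxEncoding = 0xfff, ShiftSize = 12;
    const uint64_t MaxEncodableValue = (uint64_t)MaxEncoding << ShiftSize;
    uint64_t ThisVal = std::min<uint64_t>(Offset, MaxEncodableValue);
    // ThisVal == 0xfff000. With std::min<unsigned>, Offset would first be
    // truncated to 16 and the loop would emit a wrong adjustment.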
@@ -3548,6 +3610,18 @@ static bool isCombineInstrCandidate64(unsigned Opc) {
   // Note: MSUB Wd,Wn,Wm,Wi -> Wd = Wi - WnxWm, not Wd=WnxWm - Wi.
   case AArch64::SUBXri:
   case AArch64::SUBSXri:
+  case AArch64::ADDv8i8:
+  case AArch64::ADDv16i8:
+  case AArch64::ADDv4i16:
+  case AArch64::ADDv8i16:
+  case AArch64::ADDv2i32:
+  case AArch64::ADDv4i32:
+  case AArch64::SUBv8i8:
+  case AArch64::SUBv16i8:
+  case AArch64::SUBv4i16:
+  case AArch64::SUBv8i16:
+  case AArch64::SUBv2i32:
+  case AArch64::SUBv4i32:
     return true;
   default:
     break;
@@ -3690,6 +3764,13 @@ static bool getMaddPatterns(MachineInstr &Root,
     }
   };
 
+  auto setVFound = [&](int Opcode, int Operand, MachineCombinerPattern Pattern) {
+    if (canCombine(MBB, Root.getOperand(Operand), Opcode)) {
+      Patterns.push_back(Pattern);
+      Found = true;
+    }
+  };
+
   typedef MachineCombinerPattern MCP;
 
   switch (Opc) {
@@ -3725,6 +3806,70 @@ static bool getMaddPatterns(MachineInstr &Root,
   case AArch64::SUBXri:
     setFound(AArch64::MADDXrrr, 1, AArch64::XZR, MCP::MULSUBXI_OP1);
     break;
+  case AArch64::ADDv8i8:
+    setVFound(AArch64::MULv8i8, 1, MCP::MULADDv8i8_OP1);
+    setVFound(AArch64::MULv8i8, 2, MCP::MULADDv8i8_OP2);
+    break;
+  case AArch64::ADDv16i8:
+    setVFound(AArch64::MULv16i8, 1, MCP::MULADDv16i8_OP1);
+    setVFound(AArch64::MULv16i8, 2, MCP::MULADDv16i8_OP2);
+    break;
+  case AArch64::ADDv4i16:
+    setVFound(AArch64::MULv4i16, 1, MCP::MULADDv4i16_OP1);
+    setVFound(AArch64::MULv4i16, 2, MCP::MULADDv4i16_OP2);
+    setVFound(AArch64::MULv4i16_indexed, 1, MCP::MULADDv4i16_indexed_OP1);
+    setVFound(AArch64::MULv4i16_indexed, 2, MCP::MULADDv4i16_indexed_OP2);
+    break;
+  case AArch64::ADDv8i16:
+    setVFound(AArch64::MULv8i16, 1, MCP::MULADDv8i16_OP1);
+    setVFound(AArch64::MULv8i16, 2, MCP::MULADDv8i16_OP2);
+    setVFound(AArch64::MULv8i16_indexed, 1, MCP::MULADDv8i16_indexed_OP1);
+    setVFound(AArch64::MULv8i16_indexed, 2, MCP::MULADDv8i16_indexed_OP2);
+    break;
+  case AArch64::ADDv2i32:
+    setVFound(AArch64::MULv2i32, 1, MCP::MULADDv2i32_OP1);
+    setVFound(AArch64::MULv2i32, 2, MCP::MULADDv2i32_OP2);
+    setVFound(AArch64::MULv2i32_indexed, 1, MCP::MULADDv2i32_indexed_OP1);
+    setVFound(AArch64::MULv2i32_indexed, 2, MCP::MULADDv2i32_indexed_OP2);
+    break;
+  case AArch64::ADDv4i32:
+    setVFound(AArch64::MULv4i32, 1, MCP::MULADDv4i32_OP1);
+    setVFound(AArch64::MULv4i32, 2, MCP::MULADDv4i32_OP2);
+    setVFound(AArch64::MULv4i32_indexed, 1, MCP::MULADDv4i32_indexed_OP1);
+    setVFound(AArch64::MULv4i32_indexed, 2, MCP::MULADDv4i32_indexed_OP2);
+    break;
+  case AArch64::SUBv8i8:
+    setVFound(AArch64::MULv8i8, 1, MCP::MULSUBv8i8_OP1);
+    setVFound(AArch64::MULv8i8, 2, MCP::MULSUBv8i8_OP2);
+    break;
+  case AArch64::SUBv16i8:
+    setVFound(AArch64::MULv16i8, 1, MCP::MULSUBv16i8_OP1);
+    setVFound(AArch64::MULv16i8, 2, MCP::MULSUBv16i8_OP2);
+    break;
+  case AArch64::SUBv4i16:
+    setVFound(AArch64::MULv4i16, 1, MCP::MULSUBv4i16_OP1);
+    setVFound(AArch64::MULv4i16, 2, MCP::MULSUBv4i16_OP2);
+    setVFound(AArch64::MULv4i16_indexed, 1, MCP::MULSUBv4i16_indexed_OP1);
+    setVFound(AArch64::MULv4i16_indexed, 2, MCP::MULSUBv4i16_indexed_OP2);
+    break;
+  case AArch64::SUBv8i16:
+    setVFound(AArch64::MULv8i16, 1, MCP::MULSUBv8i16_OP1);
+    setVFound(AArch64::MULv8i16, 2, MCP::MULSUBv8i16_OP2);
+    setVFound(AArch64::MULv8i16_indexed, 1, MCP::MULSUBv8i16_indexed_OP1);
+    setVFound(AArch64::MULv8i16_indexed, 2, MCP::MULSUBv8i16_indexed_OP2);
+    break;
+  case AArch64::SUBv2i32:
+    setVFound(AArch64::MULv2i32, 1, MCP::MULSUBv2i32_OP1);
+    setVFound(AArch64::MULv2i32, 2, MCP::MULSUBv2i32_OP2);
+    setVFound(AArch64::MULv2i32_indexed, 1, MCP::MULSUBv2i32_indexed_OP1);
+    setVFound(AArch64::MULv2i32_indexed, 2, MCP::MULSUBv2i32_indexed_OP2);
+    break;
+  case AArch64::SUBv4i32:
+    setVFound(AArch64::MULv4i32, 1, MCP::MULSUBv4i32_OP1);
+    setVFound(AArch64::MULv4i32, 2, MCP::MULSUBv4i32_OP2);
+    setVFound(AArch64::MULv4i32_indexed, 1, MCP::MULSUBv4i32_indexed_OP1);
+    setVFound(AArch64::MULv4i32_indexed, 2, MCP::MULSUBv4i32_indexed_OP2);
+    break;
   }
   return Found;
 }
@@ -3937,6 +4082,46 @@ bool AArch64InstrInfo::isThroughputPattern(
   case MachineCombinerPattern::FMLSv2f64_OP2:
   case MachineCombinerPattern::FMLSv4i32_indexed_OP2:
   case MachineCombinerPattern::FMLSv4f32_OP2:
+  case MachineCombinerPattern::MULADDv8i8_OP1:
+  case MachineCombinerPattern::MULADDv8i8_OP2:
+  case MachineCombinerPattern::MULADDv16i8_OP1:
+  case MachineCombinerPattern::MULADDv16i8_OP2:
+  case MachineCombinerPattern::MULADDv4i16_OP1:
+  case MachineCombinerPattern::MULADDv4i16_OP2:
+  case MachineCombinerPattern::MULADDv8i16_OP1:
+  case MachineCombinerPattern::MULADDv8i16_OP2:
+  case MachineCombinerPattern::MULADDv2i32_OP1:
+  case MachineCombinerPattern::MULADDv2i32_OP2:
+  case MachineCombinerPattern::MULADDv4i32_OP1:
+  case MachineCombinerPattern::MULADDv4i32_OP2:
+  case MachineCombinerPattern::MULSUBv8i8_OP1:
+  case MachineCombinerPattern::MULSUBv8i8_OP2:
+  case MachineCombinerPattern::MULSUBv16i8_OP1:
+  case MachineCombinerPattern::MULSUBv16i8_OP2:
+  case MachineCombinerPattern::MULSUBv4i16_OP1:
+  case MachineCombinerPattern::MULSUBv4i16_OP2:
+  case MachineCombinerPattern::MULSUBv8i16_OP1:
+  case MachineCombinerPattern::MULSUBv8i16_OP2:
+  case MachineCombinerPattern::MULSUBv2i32_OP1:
+  case MachineCombinerPattern::MULSUBv2i32_OP2:
+  case MachineCombinerPattern::MULSUBv4i32_OP1:
+  case MachineCombinerPattern::MULSUBv4i32_OP2:
+  case MachineCombinerPattern::MULADDv4i16_indexed_OP1:
+  case MachineCombinerPattern::MULADDv4i16_indexed_OP2:
+  case MachineCombinerPattern::MULADDv8i16_indexed_OP1:
+  case MachineCombinerPattern::MULADDv8i16_indexed_OP2:
+  case MachineCombinerPattern::MULADDv2i32_indexed_OP1:
+  case MachineCombinerPattern::MULADDv2i32_indexed_OP2:
+  case MachineCombinerPattern::MULADDv4i32_indexed_OP1:
+  case MachineCombinerPattern::MULADDv4i32_indexed_OP2:
+  case MachineCombinerPattern::MULSUBv4i16_indexed_OP1:
+  case MachineCombinerPattern::MULSUBv4i16_indexed_OP2:
+  case MachineCombinerPattern::MULSUBv8i16_indexed_OP1:
+  case MachineCombinerPattern::MULSUBv8i16_indexed_OP2:
+  case MachineCombinerPattern::MULSUBv2i32_indexed_OP1:
+  case MachineCombinerPattern::MULSUBv2i32_indexed_OP2:
+  case MachineCombinerPattern::MULSUBv4i32_indexed_OP1:
+  case MachineCombinerPattern::MULSUBv4i32_indexed_OP2:
     return true;
   } // end switch (Pattern)
   return false;
 }
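What the new MULADD patterns express, in illustrative pseudo-MIR (virtual register numbers are assumed, not taken from the patch):

    //   %3:fpr128 = MULv4i32 %1, %2
    //   %4:fpr128 = ADDv4i32 %0, %3    ; mul feeds operand 2 -> _OP2 pattern
    // combines to:
    //   %4:fpr128 = MLAv4i32 %0, %1, %2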
@@ -4040,6 +4225,80 @@ genFusedMultiply(MachineFunction &MF, MachineRegisterInfo &MRI,
   return MUL;
 }
 
+/// genFusedMultiplyAcc - Helper to generate fused multiply accumulate
+/// instructions.
+///
+/// \see genFusedMultiply
+static MachineInstr *genFusedMultiplyAcc(
+    MachineFunction &MF, MachineRegisterInfo &MRI, const TargetInstrInfo *TII,
+    MachineInstr &Root, SmallVectorImpl<MachineInstr *> &InsInstrs,
+    unsigned IdxMulOpd, unsigned MaddOpc, const TargetRegisterClass *RC) {
+  return genFusedMultiply(MF, MRI, TII, Root, InsInstrs, IdxMulOpd, MaddOpc, RC,
+                          FMAInstKind::Accumulator);
+}
+
+/// genNeg - Helper to generate an intermediate negation of the second operand
+/// of Root
+static Register genNeg(MachineFunction &MF, MachineRegisterInfo &MRI,
+                       const TargetInstrInfo *TII, MachineInstr &Root,
+                       SmallVectorImpl<MachineInstr *> &InsInstrs,
+                       DenseMap<unsigned, unsigned> &InstrIdxForVirtReg,
+                       unsigned MnegOpc, const TargetRegisterClass *RC) {
+  Register NewVR = MRI.createVirtualRegister(RC);
+  MachineInstrBuilder MIB =
+      BuildMI(MF, Root.getDebugLoc(), TII->get(MnegOpc), NewVR)
+          .add(Root.getOperand(2));
+  InsInstrs.push_back(MIB);
+
+  assert(InstrIdxForVirtReg.empty());
+  InstrIdxForVirtReg.insert(std::make_pair(NewVR, 0));
+
+  return NewVR;
+}
+
+/// genFusedMultiplyAccNeg - Helper to generate fused multiply accumulate
+/// instructions with an additional negation of the accumulator
+static MachineInstr *genFusedMultiplyAccNeg(
+    MachineFunction &MF, MachineRegisterInfo &MRI, const TargetInstrInfo *TII,
+    MachineInstr &Root, SmallVectorImpl<MachineInstr *> &InsInstrs,
+    DenseMap<unsigned, unsigned> &InstrIdxForVirtReg, unsigned IdxMulOpd,
+    unsigned MaddOpc, unsigned MnegOpc, const TargetRegisterClass *RC) {
+  assert(IdxMulOpd == 1);
+
+  Register NewVR =
+      genNeg(MF, MRI, TII, Root, InsInstrs, InstrIdxForVirtReg, MnegOpc, RC);
+  return genFusedMultiply(MF, MRI, TII, Root, InsInstrs, IdxMulOpd, MaddOpc, RC,
+                          FMAInstKind::Accumulator, &NewVR);
+}
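Why genNeg exists: when the multiply feeds operand 1 of a vector SUB there is no direct MLS form, so the accumulator is negated first and an MLA is used. Illustrative pseudo-MIR (virtual register numbers assumed):

    //   %2:fpr64 = MULv2i32 %0, %1
    //   %4:fpr64 = SUBv2i32 %2, %3        ; mul is operand 1 -> _OP1 pattern
    // becomes:
    //   %5:fpr64 = NEGv2i32 %3            ; inserted by genNeg
    //   %4:fpr64 = MLAv2i32 %5, %0, %1    ; (-%3) + %0 * %1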
+
+/// genFusedMultiplyIdx - Helper to generate fused multiply accumulate
+/// instructions.
+///
+/// \see genFusedMultiply
+static MachineInstr *genFusedMultiplyIdx(
+    MachineFunction &MF, MachineRegisterInfo &MRI, const TargetInstrInfo *TII,
+    MachineInstr &Root, SmallVectorImpl<MachineInstr *> &InsInstrs,
+    unsigned IdxMulOpd, unsigned MaddOpc, const TargetRegisterClass *RC) {
+  return genFusedMultiply(MF, MRI, TII, Root, InsInstrs, IdxMulOpd, MaddOpc, RC,
+                          FMAInstKind::Indexed);
+}
+
+/// genFusedMultiplyIdxNeg - Helper to generate fused multiply accumulate
+/// instructions with an additional negation of the accumulator
+static MachineInstr *genFusedMultiplyIdxNeg(
+    MachineFunction &MF, MachineRegisterInfo &MRI, const TargetInstrInfo *TII,
+    MachineInstr &Root, SmallVectorImpl<MachineInstr *> &InsInstrs,
+    DenseMap<unsigned, unsigned> &InstrIdxForVirtReg, unsigned IdxMulOpd,
+    unsigned MaddOpc, unsigned MnegOpc, const TargetRegisterClass *RC) {
+  assert(IdxMulOpd == 1);
+
+  Register NewVR =
+      genNeg(MF, MRI, TII, Root, InsInstrs, InstrIdxForVirtReg, MnegOpc, RC);
+
+  return genFusedMultiply(MF, MRI, TII, Root, InsInstrs, IdxMulOpd, MaddOpc, RC,
+                          FMAInstKind::Indexed, &NewVR);
+}
+
 /// genMaddR - Generate madd instruction and combine mul and add using
 /// an extra virtual register
 /// Example - an ADD intermediate needs to be stored in a register:
@@ -4279,6 +4538,231 @@ void AArch64InstrInfo::genAlternativeCodeSequence(
     }
     break;
   }
+
+  case MachineCombinerPattern::MULADDv8i8_OP1:
+    Opc = AArch64::MLAv8i8;
+    RC = &AArch64::FPR64RegClass;
+    MUL = genFusedMultiplyAcc(MF, MRI, TII, Root, InsInstrs, 1, Opc, RC);
+    break;
+  case MachineCombinerPattern::MULADDv8i8_OP2:
+    Opc = AArch64::MLAv8i8;
+    RC = &AArch64::FPR64RegClass;
+    MUL = genFusedMultiplyAcc(MF, MRI, TII, Root, InsInstrs, 2, Opc, RC);
+    break;
+  case MachineCombinerPattern::MULADDv16i8_OP1:
+    Opc = AArch64::MLAv16i8;
+    RC = &AArch64::FPR128RegClass;
+    MUL = genFusedMultiplyAcc(MF, MRI, TII, Root, InsInstrs, 1, Opc, RC);
+    break;
+  case MachineCombinerPattern::MULADDv16i8_OP2:
+    Opc = AArch64::MLAv16i8;
+    RC = &AArch64::FPR128RegClass;
+    MUL = genFusedMultiplyAcc(MF, MRI, TII, Root, InsInstrs, 2, Opc, RC);
+    break;
+  case MachineCombinerPattern::MULADDv4i16_OP1:
+    Opc = AArch64::MLAv4i16;
+    RC = &AArch64::FPR64RegClass;
+    MUL = genFusedMultiplyAcc(MF, MRI, TII, Root, InsInstrs, 1, Opc, RC);
+    break;
+  case MachineCombinerPattern::MULADDv4i16_OP2:
+    Opc = AArch64::MLAv4i16;
+    RC = &AArch64::FPR64RegClass;
+    MUL = genFusedMultiplyAcc(MF, MRI, TII, Root, InsInstrs, 2, Opc, RC);
+    break;
+  case MachineCombinerPattern::MULADDv8i16_OP1:
+    Opc = AArch64::MLAv8i16;
+    RC = &AArch64::FPR128RegClass;
+    MUL = genFusedMultiplyAcc(MF, MRI, TII, Root, InsInstrs, 1, Opc, RC);
+    break;
+  case MachineCombinerPattern::MULADDv8i16_OP2:
+    Opc = AArch64::MLAv8i16;
+    RC = &AArch64::FPR128RegClass;
+    MUL = genFusedMultiplyAcc(MF, MRI, TII, Root, InsInstrs, 2, Opc, RC);
+    break;
+  case MachineCombinerPattern::MULADDv2i32_OP1:
+    Opc = AArch64::MLAv2i32;
+    RC = &AArch64::FPR64RegClass;
+    MUL = genFusedMultiplyAcc(MF, MRI, TII, Root, InsInstrs, 1, Opc, RC);
+    break;
+  case MachineCombinerPattern::MULADDv2i32_OP2:
+    Opc = AArch64::MLAv2i32;
+    RC = &AArch64::FPR64RegClass;
+    MUL = genFusedMultiplyAcc(MF, MRI, TII, Root, InsInstrs, 2, Opc, RC);
+    break;
+  case MachineCombinerPattern::MULADDv4i32_OP1:
+    Opc = AArch64::MLAv4i32;
+    RC = &AArch64::FPR128RegClass;
+    MUL = genFusedMultiplyAcc(MF, MRI, TII, Root, InsInstrs, 1, Opc, RC);
+    break;
+  case MachineCombinerPattern::MULADDv4i32_OP2:
+    Opc = AArch64::MLAv4i32;
+    RC = &AArch64::FPR128RegClass;
+    MUL = genFusedMultiplyAcc(MF, MRI, TII, Root, InsInstrs, 2, Opc, RC);
+    break;
+
+  case MachineCombinerPattern::MULSUBv8i8_OP1:
+    Opc = AArch64::MLAv8i8;
+    RC = &AArch64::FPR64RegClass;
+    MUL = genFusedMultiplyAccNeg(MF, MRI, TII, Root, InsInstrs,
+                                 InstrIdxForVirtReg, 1, Opc, AArch64::NEGv8i8,
+                                 RC);
+    break;
+  case MachineCombinerPattern::MULSUBv8i8_OP2:
+    Opc = AArch64::MLSv8i8;
+    RC = &AArch64::FPR64RegClass;
+    MUL = genFusedMultiplyAcc(MF, MRI, TII, Root, InsInstrs, 2, Opc, RC);
+    break;
+  case MachineCombinerPattern::MULSUBv16i8_OP1:
+    Opc = AArch64::MLAv16i8;
+    RC = &AArch64::FPR128RegClass;
+    MUL = genFusedMultiplyAccNeg(MF, MRI, TII, Root, InsInstrs,
+                                 InstrIdxForVirtReg, 1, Opc, AArch64::NEGv16i8,
+                                 RC);
+    break;
+  case MachineCombinerPattern::MULSUBv16i8_OP2:
+    Opc = AArch64::MLSv16i8;
+    RC = &AArch64::FPR128RegClass;
+    MUL = genFusedMultiplyAcc(MF, MRI, TII, Root, InsInstrs, 2, Opc, RC);
+    break;
+  case MachineCombinerPattern::MULSUBv4i16_OP1:
+    Opc = AArch64::MLAv4i16;
+    RC = &AArch64::FPR64RegClass;
+    MUL = genFusedMultiplyAccNeg(MF, MRI, TII, Root, InsInstrs,
+                                 InstrIdxForVirtReg, 1, Opc, AArch64::NEGv4i16,
+                                 RC);
+    break;
+  case MachineCombinerPattern::MULSUBv4i16_OP2:
+    Opc = AArch64::MLSv4i16;
+    RC = &AArch64::FPR64RegClass;
+    MUL = genFusedMultiplyAcc(MF, MRI, TII, Root, InsInstrs, 2, Opc, RC);
+    break;
+  case MachineCombinerPattern::MULSUBv8i16_OP1:
+    Opc = AArch64::MLAv8i16;
+    RC = &AArch64::FPR128RegClass;
+    MUL = genFusedMultiplyAccNeg(MF, MRI, TII, Root, InsInstrs,
+                                 InstrIdxForVirtReg, 1, Opc, AArch64::NEGv8i16,
+                                 RC);
+    break;
+  case MachineCombinerPattern::MULSUBv8i16_OP2:
+    Opc = AArch64::MLSv8i16;
+    RC = &AArch64::FPR128RegClass;
+    MUL = genFusedMultiplyAcc(MF, MRI, TII, Root, InsInstrs, 2, Opc, RC);
+    break;
+  case MachineCombinerPattern::MULSUBv2i32_OP1:
+    Opc = AArch64::MLAv2i32;
+    RC = &AArch64::FPR64RegClass;
+    MUL = genFusedMultiplyAccNeg(MF, MRI, TII, Root, InsInstrs,
+                                 InstrIdxForVirtReg, 1, Opc, AArch64::NEGv2i32,
+                                 RC);
+    break;
+  case MachineCombinerPattern::MULSUBv2i32_OP2:
+    Opc = AArch64::MLSv2i32;
+    RC = &AArch64::FPR64RegClass;
+    MUL = genFusedMultiplyAcc(MF, MRI, TII, Root, InsInstrs, 2, Opc, RC);
+    break;
+  case MachineCombinerPattern::MULSUBv4i32_OP1:
+    Opc = AArch64::MLAv4i32;
+    RC = &AArch64::FPR128RegClass;
+    MUL = genFusedMultiplyAccNeg(MF, MRI, TII, Root, InsInstrs,
+                                 InstrIdxForVirtReg, 1, Opc, AArch64::NEGv4i32,
+                                 RC);
+    break;
+  case MachineCombinerPattern::MULSUBv4i32_OP2:
+    Opc = AArch64::MLSv4i32;
+    RC = &AArch64::FPR128RegClass;
+    MUL = genFusedMultiplyAcc(MF, MRI, TII, Root, InsInstrs, 2, Opc, RC);
+    break;
+
+  case MachineCombinerPattern::MULADDv4i16_indexed_OP1:
+    Opc = AArch64::MLAv4i16_indexed;
+    RC = &AArch64::FPR64RegClass;
+    MUL = genFusedMultiplyIdx(MF, MRI, TII, Root, InsInstrs, 1, Opc, RC);
+    break;
+  case MachineCombinerPattern::MULADDv4i16_indexed_OP2:
+    Opc = AArch64::MLAv4i16_indexed;
+    RC = &AArch64::FPR64RegClass;
+    MUL = genFusedMultiplyIdx(MF, MRI, TII, Root, InsInstrs, 2, Opc, RC);
+    break;
+  case MachineCombinerPattern::MULADDv8i16_indexed_OP1:
+    Opc = AArch64::MLAv8i16_indexed;
+    RC = &AArch64::FPR128RegClass;
+    MUL = genFusedMultiplyIdx(MF, MRI, TII, Root, InsInstrs, 1, Opc, RC);
+    break;
+  case MachineCombinerPattern::MULADDv8i16_indexed_OP2:
+    Opc = AArch64::MLAv8i16_indexed;
+    RC = &AArch64::FPR128RegClass;
+    MUL = genFusedMultiplyIdx(MF, MRI, TII, Root, InsInstrs, 2, Opc, RC);
+    break;
+  case MachineCombinerPattern::MULADDv2i32_indexed_OP1:
+    Opc = AArch64::MLAv2i32_indexed;
+    RC = &AArch64::FPR64RegClass;
+    MUL = genFusedMultiplyIdx(MF, MRI, TII, Root, InsInstrs, 1, Opc, RC);
+    break;
+  case MachineCombinerPattern::MULADDv2i32_indexed_OP2:
+    Opc = AArch64::MLAv2i32_indexed;
+    RC = &AArch64::FPR64RegClass;
+    MUL = genFusedMultiplyIdx(MF, MRI, TII, Root, InsInstrs, 2, Opc, RC);
+    break;
+  case MachineCombinerPattern::MULADDv4i32_indexed_OP1:
+    Opc = AArch64::MLAv4i32_indexed;
+    RC = &AArch64::FPR128RegClass;
+    MUL = genFusedMultiplyIdx(MF, MRI, TII, Root, InsInstrs, 1, Opc, RC);
+    break;
+  case MachineCombinerPattern::MULADDv4i32_indexed_OP2:
+    Opc = AArch64::MLAv4i32_indexed;
+    RC = &AArch64::FPR128RegClass;
+    MUL = genFusedMultiplyIdx(MF, MRI, TII, Root, InsInstrs, 2, Opc, RC);
+    break;
+
+  case MachineCombinerPattern::MULSUBv4i16_indexed_OP1:
+    Opc = AArch64::MLAv4i16_indexed;
+    RC = &AArch64::FPR64RegClass;
+    MUL = genFusedMultiplyIdxNeg(MF, MRI, TII, Root, InsInstrs,
+                                 InstrIdxForVirtReg, 1, Opc, AArch64::NEGv4i16,
+                                 RC);
+    break;
+  case MachineCombinerPattern::MULSUBv4i16_indexed_OP2:
+    Opc = AArch64::MLSv4i16_indexed;
+    RC = &AArch64::FPR64RegClass;
+    MUL = genFusedMultiplyIdx(MF, MRI, TII, Root, InsInstrs, 2, Opc, RC);
+    break;
+  case MachineCombinerPattern::MULSUBv8i16_indexed_OP1:
+    Opc = AArch64::MLAv8i16_indexed;
+    RC = &AArch64::FPR128RegClass;
+    MUL = genFusedMultiplyIdxNeg(MF, MRI, TII, Root, InsInstrs,
+                                 InstrIdxForVirtReg, 1, Opc, AArch64::NEGv8i16,
+                                 RC);
+    break;
+  case MachineCombinerPattern::MULSUBv8i16_indexed_OP2:
+    Opc = AArch64::MLSv8i16_indexed;
+    RC = &AArch64::FPR128RegClass;
+    MUL = genFusedMultiplyIdx(MF, MRI, TII, Root, InsInstrs, 2, Opc, RC);
+    break;
+  case MachineCombinerPattern::MULSUBv2i32_indexed_OP1:
+    Opc = AArch64::MLAv2i32_indexed;
+    RC = &AArch64::FPR64RegClass;
+    MUL = genFusedMultiplyIdxNeg(MF, MRI, TII, Root, InsInstrs,
+                                 InstrIdxForVirtReg, 1, Opc, AArch64::NEGv2i32,
+                                 RC);
+    break;
+  case MachineCombinerPattern::MULSUBv2i32_indexed_OP2:
+    Opc = AArch64::MLSv2i32_indexed;
+    RC = &AArch64::FPR64RegClass;
+    MUL = genFusedMultiplyIdx(MF, MRI, TII, Root, InsInstrs, 2, Opc, RC);
+    break;
+  case MachineCombinerPattern::MULSUBv4i32_indexed_OP1:
+    Opc = AArch64::MLAv4i32_indexed;
+    RC = &AArch64::FPR128RegClass;
+    MUL = genFusedMultiplyIdxNeg(MF, MRI, TII, Root, InsInstrs,
+                                 InstrIdxForVirtReg, 1, Opc, AArch64::NEGv4i32,
+                                 RC);
+    break;
+  case MachineCombinerPattern::MULSUBv4i32_indexed_OP2:
+    Opc = AArch64::MLSv4i32_indexed;
+    RC = &AArch64::FPR128RegClass;
+    MUL = genFusedMultiplyIdx(MF, MRI, TII, Root, InsInstrs, 2, Opc, RC);
+    break;
+
   // Floating Point Support
   case MachineCombinerPattern::FMULADDH_OP1:
     Opc = AArch64::FMADDHrrr;
@@ -5037,8 +5521,99 @@ AArch64InstrInfo::findRegisterToSaveLRTo(const outliner::Candidate &C) const {
   return 0u;
 }
 
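The three consensus helpers introduced in the next hunk are meant to be composed pairwise. A sketch of the combined predicate (this mirrors the std::adjacent_find lambda used further below):

    auto Compatible = [](const outliner::Candidate &a,
                         const outliner::Candidate &b) {
      return outliningCandidatesSigningScopeConsensus(a, b) &&
             outliningCandidatesSigningKeyConsensus(a, b) &&
             outliningCandidatesV8_3OpsConsensus(a, b);
    };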
-outliner::OutlinedFunction
-AArch64InstrInfo::getOutliningCandidateInfo(
+static bool
+outliningCandidatesSigningScopeConsensus(const outliner::Candidate &a,
+                                         const outliner::Candidate &b) {
+  const Function &Fa = a.getMF()->getFunction();
+  const Function &Fb = b.getMF()->getFunction();
+
+  // If none of the functions have the "sign-return-address" attribute their
+  // signing behaviour is equal
+  if (!Fa.hasFnAttribute("sign-return-address") &&
+      !Fb.hasFnAttribute("sign-return-address")) {
+    return true;
+  }
+
+  // If both functions have the "sign-return-address" attribute their signing
+  // behaviour is equal, if the values of the attributes are equal
+  if (Fa.hasFnAttribute("sign-return-address") &&
+      Fb.hasFnAttribute("sign-return-address")) {
+    StringRef ScopeA =
+        Fa.getFnAttribute("sign-return-address").getValueAsString();
+    StringRef ScopeB =
+        Fb.getFnAttribute("sign-return-address").getValueAsString();
+    return ScopeA.equals(ScopeB);
+  }
+
+  // If function B doesn't have the "sign-return-address" attribute but A does,
+  // the functions' signing behaviour is equal if A's value for
+  // "sign-return-address" is "none" and vice versa.
+  if (Fa.hasFnAttribute("sign-return-address")) {
+    StringRef ScopeA =
+        Fa.getFnAttribute("sign-return-address").getValueAsString();
+    return ScopeA.equals("none");
+  }
+
+  if (Fb.hasFnAttribute("sign-return-address")) {
+    StringRef ScopeB =
+        Fb.getFnAttribute("sign-return-address").getValueAsString();
+    return ScopeB.equals("none");
+  }
+
+  llvm_unreachable("Unknown combination of sign-return-address attributes");
+}
+
+static bool
+outliningCandidatesSigningKeyConsensus(const outliner::Candidate &a,
+                                       const outliner::Candidate &b) {
+  const Function &Fa = a.getMF()->getFunction();
+  const Function &Fb = b.getMF()->getFunction();
+
+  // If none of the functions have the "sign-return-address-key" attribute
+  // their keys are equal
+  if (!Fa.hasFnAttribute("sign-return-address-key") &&
+      !Fb.hasFnAttribute("sign-return-address-key")) {
+    return true;
+  }
+
+  // If both functions have the "sign-return-address-key" attribute their
+  // keys are equal if the values of "sign-return-address-key" are equal
+  if (Fa.hasFnAttribute("sign-return-address-key") &&
+      Fb.hasFnAttribute("sign-return-address-key")) {
+    StringRef KeyA =
+        Fa.getFnAttribute("sign-return-address-key").getValueAsString();
+    StringRef KeyB =
+        Fb.getFnAttribute("sign-return-address-key").getValueAsString();
+    return KeyA.equals(KeyB);
+  }
+
+  // If B doesn't have the "sign-return-address-key" attribute, both keys are
+  // equal, if function a has the default key (a_key)
+  if (Fa.hasFnAttribute("sign-return-address-key")) {
+    StringRef KeyA =
+        Fa.getFnAttribute("sign-return-address-key").getValueAsString();
+    return KeyA.equals_lower("a_key");
+  }
+
+  if (Fb.hasFnAttribute("sign-return-address-key")) {
+    StringRef KeyB =
+        Fb.getFnAttribute("sign-return-address-key").getValueAsString();
+    return KeyB.equals_lower("a_key");
+  }
+
+  llvm_unreachable("Unknown combination of sign-return-address-key attributes");
+}
+
+static bool outliningCandidatesV8_3OpsConsensus(const outliner::Candidate &a,
+                                                const outliner::Candidate &b) {
+  const AArch64Subtarget &SubtargetA =
+      a.getMF()->getSubtarget<AArch64Subtarget>();
+  const AArch64Subtarget &SubtargetB =
+      b.getMF()->getSubtarget<AArch64Subtarget>();
+  return SubtargetA.hasV8_3aOps() == SubtargetB.hasV8_3aOps();
+}
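Inside getOutliningCandidateInfo below, candidates whose outlined range modifies SP with a nonzero net effect are dropped, because a signed frame must see the same SP at PAC and AUT time. The bookkeeping reduces to (illustrative values):

    int SPValue = 0;
    SPValue -= 16;                  // sub sp, sp, #16 inside the candidate
    SPValue += 16;                  // matching add sp, sp, #16
    bool Illegal = (SPValue != 0);  // false: net-zero SP motion is allowed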
+
+outliner::OutlinedFunction AArch64InstrInfo::getOutliningCandidateInfo(
     std::vector<outliner::Candidate> &RepeatedSequenceLocs) const {
   outliner::Candidate &FirstCand = RepeatedSequenceLocs[0];
   unsigned SequenceSize =
@@ -5046,12 +5621,115 @@ AArch64InstrInfo::getOutliningCandidateInfo(
       [this](unsigned Sum, const MachineInstr &MI) {
         return Sum + getInstSizeInBytes(MI);
       });
+  unsigned NumBytesToCreateFrame = 0;
+
+  // We only allow outlining for functions having exactly matching return
+  // address signing attributes, i.e., all share the same value for the
+  // attribute "sign-return-address" and all share the same type of key they
+  // are signed with.
+  // Additionally we require all functions to simultaneously either support
+  // v8.3a features or not. Otherwise an outlined function could get signed
+  // using dedicated v8.3 instructions and a call from a function that doesn't
+  // support v8.3 instructions would therefore be invalid.
+  if (std::adjacent_find(
+          RepeatedSequenceLocs.begin(), RepeatedSequenceLocs.end(),
+          [](const outliner::Candidate &a, const outliner::Candidate &b) {
+            // Return true if a and b are non-equal w.r.t. return address
+            // signing or support of v8.3a features
+            if (outliningCandidatesSigningScopeConsensus(a, b) &&
+                outliningCandidatesSigningKeyConsensus(a, b) &&
+                outliningCandidatesV8_3OpsConsensus(a, b)) {
+              return false;
+            }
+            return true;
+          }) != RepeatedSequenceLocs.end()) {
+    return outliner::OutlinedFunction();
+  }
+
+  // Since at this point all candidates agree on their return address signing,
+  // picking just one is fine. If the candidate functions potentially sign their
+  // return addresses, the outlined function should do the same. Note that in
+  // the case of "sign-return-address"="non-leaf" this is an assumption: It is
+  // not certainly true that the outlined function will have to sign its return
+  // address but this decision is made later, when the decision to outline
+  // has already been made.
+  // The same holds for the number of additional instructions we need: On
+  // v8.3a RET can be replaced by RETAA/RETAB and no AUT instruction is
+  // necessary. However, at this point we don't know if the outlined function
+  // will have a RET instruction so we assume the worst.
+  const Function &FCF = FirstCand.getMF()->getFunction();
+  const TargetRegisterInfo &TRI = getRegisterInfo();
+  if (FCF.hasFnAttribute("sign-return-address")) {
+    // One PAC and one AUT instructions
+    NumBytesToCreateFrame += 8;
+
+    // We have to check if sp modifying instructions would get outlined.
+    // If so we only allow outlining if sp is unchanged overall, so matching
+    // sub and add instructions are okay to outline, all other sp modifications
+    // are not
+    auto hasIllegalSPModification = [&TRI](outliner::Candidate &C) {
+      int SPValue = 0;
+      MachineBasicBlock::iterator MBBI = C.front();
+      for (;;) {
+        if (MBBI->modifiesRegister(AArch64::SP, &TRI)) {
+          switch (MBBI->getOpcode()) {
+          case AArch64::ADDXri:
+          case AArch64::ADDWri:
+            assert(MBBI->getNumOperands() == 4 && "Wrong number of operands");
+            assert(MBBI->getOperand(2).isImm() &&
+                   "Expected operand to be immediate");
+            assert(MBBI->getOperand(1).isReg() &&
+                   "Expected operand to be a register");
+            // Check if the add just increments sp. If so, we search for
+            // matching sub instructions that decrement sp. If not, the
+            // modification is illegal
+            if (MBBI->getOperand(1).getReg() == AArch64::SP)
+              SPValue += MBBI->getOperand(2).getImm();
+            else
+              return true;
+            break;
+          case AArch64::SUBXri:
+          case AArch64::SUBWri:
+            assert(MBBI->getNumOperands() == 4 && "Wrong number of operands");
+            assert(MBBI->getOperand(2).isImm() &&
+                   "Expected operand to be immediate");
+            assert(MBBI->getOperand(1).isReg() &&
+                   "Expected operand to be a register");
+            // Check if the sub just decrements sp. If so, we search for
+            // matching add instructions that increment sp. If not, the
+            // modification is illegal
+            if (MBBI->getOperand(1).getReg() == AArch64::SP)
+              SPValue -= MBBI->getOperand(2).getImm();
+            else
+              return true;
+            break;
+          default:
+            return true;
+          }
+        }
+        if (MBBI == C.back())
+          break;
+        ++MBBI;
+      }
+      if (SPValue)
+        return true;
+      return false;
+    };
+    // Remove candidates with illegal stack modifying instructions
+    RepeatedSequenceLocs.erase(std::remove_if(RepeatedSequenceLocs.begin(),
+                                              RepeatedSequenceLocs.end(),
+                                              hasIllegalSPModification),
+                               RepeatedSequenceLocs.end());
+
+    // If the sequence doesn't have enough candidates left, then we're done.
+    if (RepeatedSequenceLocs.size() < 2)
+      return outliner::OutlinedFunction();
+  }
 
   // Properties about candidate MBBs that hold for all of them.
   unsigned FlagsSetInAll = 0xF;
 
   // Compute liveness information for each candidate, and set FlagsSetInAll.
-  const TargetRegisterInfo &TRI = getRegisterInfo();
   std::for_each(RepeatedSequenceLocs.begin(), RepeatedSequenceLocs.end(),
                 [&FlagsSetInAll](outliner::Candidate &C) {
                   FlagsSetInAll &= C.Flags;
@@ -5107,7 +5785,7 @@ AArch64InstrInfo::getOutliningCandidateInfo(
   };
 
   unsigned FrameID = MachineOutlinerDefault;
-  unsigned NumBytesToCreateFrame = 4;
+  NumBytesToCreateFrame += 4;
 
   bool HasBTI = any_of(RepeatedSequenceLocs, [](outliner::Candidate &C) {
     return C.getMF()->getFunction().hasFnAttribute("branch-target-enforcement");
@@ -5190,11 +5868,21 @@ AArch64InstrInfo::getOutliningCandidateInfo(
   unsigned NumBytesNoStackCalls = 0;
   std::vector<outliner::Candidate> CandidatesWithoutStackFixups;
 
+  // Check if we have to save LR.
   for (outliner::Candidate &C : RepeatedSequenceLocs) {
     C.initLRU(TRI);
 
+    // If we have a noreturn caller, then we're going to be conservative and
+    // say that we have to save LR. If we don't have a ret at the end of the
+    // block, then we can't reason about liveness accurately.
+    //
+    // FIXME: We can probably do better than always disabling this in
+    // noreturn functions by fixing up the liveness info.
+    bool IsNoReturn =
+        C.getMF()->getFunction().hasFnAttribute(Attribute::NoReturn);
+
     // Is LR available? If so, we don't need a save.
-    if (C.LRU.available(AArch64::LR)) {
+    if (C.LRU.available(AArch64::LR) && !IsNoReturn) {
      NumBytesNoStackCalls += 4;
      C.setCallInfo(MachineOutlinerNoLRSave, 4);
      CandidatesWithoutStackFixups.push_back(C);
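The essence of the new noreturn guard in the hunk above, as a sketch (C assumed to be an initialized outliner::Candidate): LR may only be left unsaved when it is demonstrably free and the caller can be trusted to return normally.

    bool CanOmitLRSave =
        C.LRU.available(AArch64::LR) &&
        !C.getMF()->getFunction().hasFnAttribute(Attribute::NoReturn);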
@@ -5376,6 +6064,19 @@ AArch64InstrInfo::getOutliningType(MachineBasicBlock::iterator &MIT,
   MachineFunction *MF = MBB->getParent();
   AArch64FunctionInfo *FuncInfo = MF->getInfo<AArch64FunctionInfo>();
 
+  // Don't outline anything used for return address signing. The outlined
+  // function will get signed later if needed
+  switch (MI.getOpcode()) {
+  case AArch64::PACIASP:
+  case AArch64::PACIBSP:
+  case AArch64::AUTIASP:
+  case AArch64::AUTIBSP:
+  case AArch64::RETAA:
+  case AArch64::RETAB:
+  case AArch64::EMITBKEY:
+    return outliner::InstrType::Illegal;
+  }
+
   // Don't outline LOHs.
   if (FuncInfo->getLOHRelated().count(&MI))
     return outliner::InstrType::Illegal;
@@ -5528,6 +6229,59 @@ void AArch64InstrInfo::fixupPostOutline(MachineBasicBlock &MBB) const {
   }
 }
 
+static void signOutlinedFunction(MachineFunction &MF, MachineBasicBlock &MBB,
+                                 bool ShouldSignReturnAddr,
+                                 bool ShouldSignReturnAddrWithAKey) {
+  if (ShouldSignReturnAddr) {
+    MachineBasicBlock::iterator MBBPAC = MBB.begin();
+    MachineBasicBlock::iterator MBBAUT = MBB.getFirstTerminator();
+    const AArch64Subtarget &Subtarget = MF.getSubtarget<AArch64Subtarget>();
+    const TargetInstrInfo *TII = Subtarget.getInstrInfo();
+    DebugLoc DL;
+
+    if (MBBAUT != MBB.end())
+      DL = MBBAUT->getDebugLoc();
+
+    // At the very beginning of the basic block we insert the following
+    // depending on the key type
+    //
+    // a_key:                   b_key:
+    //   PACIASP                  EMITBKEY
+    //   CFI_INSTRUCTION          PACIBSP
+    //                            CFI_INSTRUCTION
+    if (ShouldSignReturnAddrWithAKey) {
+      BuildMI(MBB, MBBPAC, DebugLoc(), TII->get(AArch64::PACIASP))
+          .setMIFlag(MachineInstr::FrameSetup);
+    } else {
+      BuildMI(MBB, MBBPAC, DebugLoc(), TII->get(AArch64::EMITBKEY))
+          .setMIFlag(MachineInstr::FrameSetup);
+      BuildMI(MBB, MBBPAC, DebugLoc(), TII->get(AArch64::PACIBSP))
+          .setMIFlag(MachineInstr::FrameSetup);
+    }
+    unsigned CFIIndex =
+        MF.addFrameInst(MCCFIInstruction::createNegateRAState(nullptr));
+    BuildMI(MBB, MBBPAC, DebugLoc(), TII->get(AArch64::CFI_INSTRUCTION))
+        .addCFIIndex(CFIIndex)
+        .setMIFlags(MachineInstr::FrameSetup);
+
+    // If v8.3a features are available we can replace a RET instruction by
+    // RETAA or RETAB and omit the AUT instructions
+    if (Subtarget.hasV8_3aOps() && MBBAUT != MBB.end() &&
+        MBBAUT->getOpcode() == AArch64::RET) {
+      BuildMI(MBB, MBBAUT, DL,
+              TII->get(ShouldSignReturnAddrWithAKey ? AArch64::RETAA
+                                                    : AArch64::RETAB))
+          .copyImplicitOps(*MBBAUT);
+      MBB.erase(MBBAUT);
+    } else {
+      BuildMI(MBB, MBBAUT, DL,
+              TII->get(ShouldSignReturnAddrWithAKey ? AArch64::AUTIASP
+                                                    : AArch64::AUTIBSP))
+          .setMIFlag(MachineInstr::FrameDestroy);
+    }
+  }
+}
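A hedged usage sketch of the helper above (flag values assumed): signing an outlined body with the B key inserts EMITBKEY, PACIBSP, and the CFI negate-ra-state note at the top, then either an AUTIBSP before the RET or, on v8.3a, a fused RETAB.

    signOutlinedFunction(MF, MBB, /*ShouldSignReturnAddr=*/true,
                         /*ShouldSignReturnAddrWithAKey=*/false);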
+
 void AArch64InstrInfo::buildOutlinedFrame(
     MachineBasicBlock &MBB, MachineFunction &MF,
     const outliner::OutlinedFunction &OF) const {
@@ -5543,16 +6297,19 @@ void AArch64InstrInfo::buildOutlinedFrame(
       TailOpcode = AArch64::TCRETURNriALL;
     }
     MachineInstr *TC = BuildMI(MF, DebugLoc(), get(TailOpcode))
-                          .add(Call->getOperand(0))
-                          .addImm(0);
+                           .add(Call->getOperand(0))
+                           .addImm(0);
     MBB.insert(MBB.end(), TC);
     Call->eraseFromParent();
   }
 
+  bool IsLeafFunction = true;
+
   // Is there a call in the outlined range?
-  auto IsNonTailCall = [](MachineInstr &MI) {
+  auto IsNonTailCall = [](const MachineInstr &MI) {
     return MI.isCall() && !MI.isReturn();
   };
+
   if (std::any_of(MBB.instr_begin(), MBB.instr_end(), IsNonTailCall)) {
     // Fix up the instructions in the range, since we're going to modify the
     // stack.
     assert(OF.FrameConstructionID != MachineOutlinerDefault &&
            "Can only fix up stack references once");
     fixupPostOutline(MBB);
 
+    IsLeafFunction = false;
+
     // LR has to be a live in so that we can save it.
     MBB.addLiveIn(AArch64::LR);
@@ -5606,16 +6365,47 @@ void AArch64InstrInfo::buildOutlinedFrame(
     Et = MBB.insert(Et, LDRXpost);
   }
 
+  // If a bunch of candidates reach this point they must agree on their return
+  // address signing. It is therefore enough to just consider the signing
+  // behaviour of one of them
+  const Function &CF = OF.Candidates.front().getMF()->getFunction();
+  bool ShouldSignReturnAddr = false;
+  if (CF.hasFnAttribute("sign-return-address")) {
+    StringRef Scope =
+        CF.getFnAttribute("sign-return-address").getValueAsString();
+    if (Scope.equals("all"))
+      ShouldSignReturnAddr = true;
+    else if (Scope.equals("non-leaf") && !IsLeafFunction)
+      ShouldSignReturnAddr = true;
+  }
+
+  // a_key is the default
+  bool ShouldSignReturnAddrWithAKey = true;
+  if (CF.hasFnAttribute("sign-return-address-key")) {
+    const StringRef Key =
+        CF.getFnAttribute("sign-return-address-key").getValueAsString();
+    // Key can either be a_key or b_key
+    assert((Key.equals_lower("a_key") || Key.equals_lower("b_key")) &&
+           "Return address signing key must be either a_key or b_key");
+    ShouldSignReturnAddrWithAKey = Key.equals_lower("a_key");
+  }
+
   // If this is a tail call outlined function, then there's already a return.
   if (OF.FrameConstructionID == MachineOutlinerTailCall ||
-      OF.FrameConstructionID == MachineOutlinerThunk)
+      OF.FrameConstructionID == MachineOutlinerThunk) {
+    signOutlinedFunction(MF, MBB, ShouldSignReturnAddr,
+                         ShouldSignReturnAddrWithAKey);
     return;
+  }
 
   // It's not a tail call, so we have to insert the return ourselves.
   MachineInstr *ret = BuildMI(MF, DebugLoc(), get(AArch64::RET))
                           .addReg(AArch64::LR, RegState::Undef);
   MBB.insert(MBB.end(), ret);
 
+  signOutlinedFunction(MF, MBB, ShouldSignReturnAddr,
+                       ShouldSignReturnAddrWithAKey);
+
   // Did we have to modify the stack by saving the link register?
   if (OF.FrameConstructionID != MachineOutlinerDefault)
     return;
@@ -5702,29 +6492,126 @@ bool AArch64InstrInfo::shouldOutlineFromFunctionByDefault(
   return MF.getFunction().hasMinSize();
 }
 
-bool AArch64InstrInfo::isCopyInstrImpl(
-    const MachineInstr &MI, const MachineOperand *&Source,
-    const MachineOperand *&Destination) const {
+Optional<DestSourcePair>
+AArch64InstrInfo::isCopyInstrImpl(const MachineInstr &MI) const {
   // AArch64::ORRWrs and AArch64::ORRXrs with WZR/XZR reg
   // and zero immediate operands used as an alias for mov instruction.
   if (MI.getOpcode() == AArch64::ORRWrs &&
       MI.getOperand(1).getReg() == AArch64::WZR &&
       MI.getOperand(3).getImm() == 0x0) {
-    Destination = &MI.getOperand(0);
-    Source = &MI.getOperand(2);
-    return true;
+    return DestSourcePair{MI.getOperand(0), MI.getOperand(2)};
   }
 
   if (MI.getOpcode() == AArch64::ORRXrs &&
      MI.getOperand(1).getReg() == AArch64::XZR &&
      MI.getOperand(3).getImm() == 0x0) {
-    Destination = &MI.getOperand(0);
-    Source = &MI.getOperand(2);
-    return true;
+    return DestSourcePair{MI.getOperand(0), MI.getOperand(2)};
   }
 
-  return false;
+  return None;
+}
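With the Optional-returning hook in place, callers go through the generic TargetInstrInfo::isCopyInstr wrapper. A hedged usage sketch (MI and TII assumed in scope):

    // E.g. for "mov w0, w1", i.e. ORRWrs w0, wzr, w1, #0:
    if (auto DestSrc = TII->isCopyInstr(MI)) {
      Register Dst = DestSrc->Destination->getReg(); // w0
      Register Src = DestSrc->Source->getReg();      // w1
    }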
+
+Optional<RegImmPair> AArch64InstrInfo::isAddImmediate(const MachineInstr &MI,
+                                                      Register Reg) const {
+  int Sign = 1;
+  int64_t Offset = 0;
+
+  // TODO: Handle cases where Reg is a super- or sub-register of the
+  // destination register.
+  if (Reg != MI.getOperand(0).getReg())
+    return None;
+
+  switch (MI.getOpcode()) {
+  default:
+    return None;
+  case AArch64::SUBWri:
+  case AArch64::SUBXri:
+  case AArch64::SUBSWri:
+  case AArch64::SUBSXri:
+    Sign *= -1;
+    LLVM_FALLTHROUGH;
+  case AArch64::ADDSWri:
+  case AArch64::ADDSXri:
+  case AArch64::ADDWri:
+  case AArch64::ADDXri: {
+    // TODO: Third operand can be global address (usually some string).
+    if (!MI.getOperand(0).isReg() || !MI.getOperand(1).isReg() ||
+        !MI.getOperand(2).isImm())
+      return None;
+    Offset = MI.getOperand(2).getImm() * Sign;
+    int Shift = MI.getOperand(3).getImm();
+    assert((Shift == 0 || Shift == 12) && "Shift can be either 0 or 12");
+    Offset = Offset << Shift;
+  }
+  }
+  return RegImmPair{MI.getOperand(1).getReg(), Offset};
+}
+
+/// If the given ORR instruction is a copy, and \p DescribedReg overlaps with
+/// the destination register then, if possible, describe the value in terms of
+/// the source register.
+static Optional<ParamLoadedValue>
+describeORRLoadedValue(const MachineInstr &MI, Register DescribedReg,
+                       const TargetInstrInfo *TII,
+                       const TargetRegisterInfo *TRI) {
+  auto DestSrc = TII->isCopyInstr(MI);
+  if (!DestSrc)
+    return None;
+
+  Register DestReg = DestSrc->Destination->getReg();
+  Register SrcReg = DestSrc->Source->getReg();
+
+  auto Expr = DIExpression::get(MI.getMF()->getFunction().getContext(), {});
+
+  // If the described register is the destination, just return the source.
+  if (DestReg == DescribedReg)
+    return ParamLoadedValue(MachineOperand::CreateReg(SrcReg, false), Expr);
+
+  // ORRWrs zero-extends to 64-bits, so we need to consider such cases.
+  if (MI.getOpcode() == AArch64::ORRWrs &&
+      TRI->isSuperRegister(DestReg, DescribedReg))
+    return ParamLoadedValue(MachineOperand::CreateReg(SrcReg, false), Expr);
+
+  // We may need to describe the lower part of a ORRXrs move.
+  if (MI.getOpcode() == AArch64::ORRXrs &&
+      TRI->isSubRegister(DestReg, DescribedReg)) {
+    Register SrcSubReg = TRI->getSubReg(SrcReg, AArch64::sub_32);
+    return ParamLoadedValue(MachineOperand::CreateReg(SrcSubReg, false), Expr);
+  }
+
+  assert(!TRI->isSuperOrSubRegisterEq(DestReg, DescribedReg) &&
+         "Unhandled ORR[XW]rs copy case");
+
+  return None;
+}
+
+Optional<ParamLoadedValue>
+AArch64InstrInfo::describeLoadedValue(const MachineInstr &MI,
+                                      Register Reg) const {
+  const MachineFunction *MF = MI.getMF();
+  const TargetRegisterInfo *TRI = MF->getSubtarget().getRegisterInfo();
+  switch (MI.getOpcode()) {
+  case AArch64::MOVZWi:
+  case AArch64::MOVZXi: {
+    // MOVZWi may be used for producing zero-extended 32-bit immediates in
+    // 64-bit parameters, so we need to consider super-registers.
+    if (!TRI->isSuperRegisterEq(MI.getOperand(0).getReg(), Reg))
+      return None;
+
+    if (!MI.getOperand(1).isImm())
+      return None;
+    int64_t Immediate = MI.getOperand(1).getImm();
+    int Shift = MI.getOperand(2).getImm();
+    return ParamLoadedValue(MachineOperand::CreateImm(Immediate << Shift),
+                            nullptr);
+  }
+  case AArch64::ORRWrs:
+  case AArch64::ORRXrs:
+    return describeORRLoadedValue(MI, Reg, this, TRI);
+  }
 
+  return TargetInstrInfo::describeLoadedValue(MI, Reg);
+}
+
 #define GET_INSTRINFO_HELPERS
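A worked example for the MOVZ case of describeLoadedValue above (illustrative operands): the loaded value is just the immediate shifted into place.

    // MOVZXi x0, #0x12, lsl #16 describes x0 as holding:
    int64_t Immediate = 0x12;
    int Shift = 16;
    int64_t Loaded = Immediate << Shift; // 0x120000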
