diff options
Diffstat (limited to 'llvm/lib/Target/AArch64/AArch64ISelLowering.cpp')
| -rw-r--r-- | llvm/lib/Target/AArch64/AArch64ISelLowering.cpp | 1217 |
1 files changed, 1042 insertions, 175 deletions
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp index 2746117e8ee5..d45a80057564 100644 --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -10,9 +10,9 @@ // //===----------------------------------------------------------------------===// -#include "AArch64ExpandImm.h" #include "AArch64ISelLowering.h" #include "AArch64CallingConvention.h" +#include "AArch64ExpandImm.h" #include "AArch64MachineFunctionInfo.h" #include "AArch64PerfectShuffle.h" #include "AArch64RegisterInfo.h" @@ -58,6 +58,7 @@ #include "llvm/IR/Instructions.h" #include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/Intrinsics.h" +#include "llvm/IR/IntrinsicsAArch64.h" #include "llvm/IR/Module.h" #include "llvm/IR/OperandTraits.h" #include "llvm/IR/PatternMatch.h" @@ -178,11 +179,25 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM, addRegisterClass(MVT::nxv2f16, &AArch64::ZPRRegClass); addRegisterClass(MVT::nxv4f16, &AArch64::ZPRRegClass); addRegisterClass(MVT::nxv8f16, &AArch64::ZPRRegClass); - addRegisterClass(MVT::nxv1f32, &AArch64::ZPRRegClass); addRegisterClass(MVT::nxv2f32, &AArch64::ZPRRegClass); addRegisterClass(MVT::nxv4f32, &AArch64::ZPRRegClass); - addRegisterClass(MVT::nxv1f64, &AArch64::ZPRRegClass); addRegisterClass(MVT::nxv2f64, &AArch64::ZPRRegClass); + + for (auto VT : { MVT::nxv16i8, MVT::nxv8i16, MVT::nxv4i32, MVT::nxv2i64 }) { + setOperationAction(ISD::SADDSAT, VT, Legal); + setOperationAction(ISD::UADDSAT, VT, Legal); + setOperationAction(ISD::SSUBSAT, VT, Legal); + setOperationAction(ISD::USUBSAT, VT, Legal); + setOperationAction(ISD::SMAX, VT, Legal); + setOperationAction(ISD::UMAX, VT, Legal); + setOperationAction(ISD::SMIN, VT, Legal); + setOperationAction(ISD::UMIN, VT, Legal); + } + + for (auto VT : + { MVT::nxv2i8, MVT::nxv2i16, MVT::nxv2i32, MVT::nxv2i64, MVT::nxv4i8, + MVT::nxv4i16, MVT::nxv4i32, MVT::nxv8i8, MVT::nxv8i16 }) + setOperationAction(ISD::SIGN_EXTEND_INREG, VT, Legal); } // Compute derived properties from the register classes @@ -422,14 +437,10 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM, setOperationAction(ISD::FSUB, MVT::v4f16, Promote); setOperationAction(ISD::FMUL, MVT::v4f16, Promote); setOperationAction(ISD::FDIV, MVT::v4f16, Promote); - setOperationAction(ISD::FP_EXTEND, MVT::v4f16, Promote); - setOperationAction(ISD::FP_ROUND, MVT::v4f16, Promote); AddPromotedToType(ISD::FADD, MVT::v4f16, MVT::v4f32); AddPromotedToType(ISD::FSUB, MVT::v4f16, MVT::v4f32); AddPromotedToType(ISD::FMUL, MVT::v4f16, MVT::v4f32); AddPromotedToType(ISD::FDIV, MVT::v4f16, MVT::v4f32); - AddPromotedToType(ISD::FP_EXTEND, MVT::v4f16, MVT::v4f32); - AddPromotedToType(ISD::FP_ROUND, MVT::v4f16, MVT::v4f32); setOperationAction(ISD::FABS, MVT::v4f16, Expand); setOperationAction(ISD::FNEG, MVT::v4f16, Expand); @@ -510,6 +521,10 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM, setOperationAction(ISD::ATOMIC_LOAD_AND, MVT::i32, Custom); setOperationAction(ISD::ATOMIC_LOAD_AND, MVT::i64, Custom); + // 128-bit loads and stores can be done without expanding + setOperationAction(ISD::LOAD, MVT::i128, Custom); + setOperationAction(ISD::STORE, MVT::i128, Custom); + // Lower READCYCLECOUNTER using an mrs from PMCCNTR_EL0. // This requires the Performance Monitors extension. if (Subtarget->hasPerfMon()) @@ -525,6 +540,12 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM, setOperationAction(ISD::FSINCOS, MVT::f32, Expand); } + if (Subtarget->getTargetTriple().isOSMSVCRT()) { + // MSVCRT doesn't have powi; fall back to pow + setLibcallName(RTLIB::POWI_F32, nullptr); + setLibcallName(RTLIB::POWI_F64, nullptr); + } + // Make floating-point constants legal for the large code model, so they don't // become loads from the constant pool. if (Subtarget->isTargetMachO() && TM.getCodeModel() == CodeModel::Large) { @@ -601,7 +622,7 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM, setTargetDAGCombine(ISD::ANY_EXTEND); setTargetDAGCombine(ISD::ZERO_EXTEND); setTargetDAGCombine(ISD::SIGN_EXTEND); - setTargetDAGCombine(ISD::BITCAST); + setTargetDAGCombine(ISD::SIGN_EXTEND_INREG); setTargetDAGCombine(ISD::CONCAT_VECTORS); setTargetDAGCombine(ISD::STORE); if (Subtarget->supportsAddressTopByteIgnored()) @@ -734,14 +755,20 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM, setOperationAction(ISD::MUL, MVT::v4i32, Custom); setOperationAction(ISD::MUL, MVT::v2i64, Custom); - // Vector reductions for (MVT VT : { MVT::v8i8, MVT::v4i16, MVT::v2i32, MVT::v16i8, MVT::v8i16, MVT::v4i32, MVT::v2i64 }) { + // Vector reductions setOperationAction(ISD::VECREDUCE_ADD, VT, Custom); setOperationAction(ISD::VECREDUCE_SMAX, VT, Custom); setOperationAction(ISD::VECREDUCE_SMIN, VT, Custom); setOperationAction(ISD::VECREDUCE_UMAX, VT, Custom); setOperationAction(ISD::VECREDUCE_UMIN, VT, Custom); + + // Saturates + setOperationAction(ISD::SADDSAT, VT, Legal); + setOperationAction(ISD::UADDSAT, VT, Legal); + setOperationAction(ISD::SSUBSAT, VT, Legal); + setOperationAction(ISD::USUBSAT, VT, Legal); } for (MVT VT : { MVT::v4f16, MVT::v2f32, MVT::v8f16, MVT::v4f32, MVT::v2f64 }) { @@ -802,10 +829,15 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM, } if (Subtarget->hasSVE()) { + // FIXME: Add custom lowering of MLOAD to handle different passthrus (not a + // splat of 0 or undef) once vector selects supported in SVE codegen. See + // D68877 for more details. for (MVT VT : MVT::integer_scalable_vector_valuetypes()) { - if (isTypeLegal(VT) && VT.getVectorElementType() != MVT::i1) + if (isTypeLegal(VT)) setOperationAction(ISD::SPLAT_VECTOR, VT, Custom); } + setOperationAction(ISD::INTRINSIC_WO_CHAIN, MVT::i8, Custom); + setOperationAction(ISD::INTRINSIC_WO_CHAIN, MVT::i16, Custom); } PredictableSelectIsExpensive = Subtarget->predictableSelectIsExpensive(); @@ -1257,6 +1289,19 @@ const char *AArch64TargetLowering::getTargetNodeName(unsigned Opcode) const { case AArch64ISD::UMINV: return "AArch64ISD::UMINV"; case AArch64ISD::SMAXV: return "AArch64ISD::SMAXV"; case AArch64ISD::UMAXV: return "AArch64ISD::UMAXV"; + case AArch64ISD::SMAXV_PRED: return "AArch64ISD::SMAXV_PRED"; + case AArch64ISD::UMAXV_PRED: return "AArch64ISD::UMAXV_PRED"; + case AArch64ISD::SMINV_PRED: return "AArch64ISD::SMINV_PRED"; + case AArch64ISD::UMINV_PRED: return "AArch64ISD::UMINV_PRED"; + case AArch64ISD::ORV_PRED: return "AArch64ISD::ORV_PRED"; + case AArch64ISD::EORV_PRED: return "AArch64ISD::EORV_PRED"; + case AArch64ISD::ANDV_PRED: return "AArch64ISD::ANDV_PRED"; + case AArch64ISD::CLASTA_N: return "AArch64ISD::CLASTA_N"; + case AArch64ISD::CLASTB_N: return "AArch64ISD::CLASTB_N"; + case AArch64ISD::LASTA: return "AArch64ISD::LASTA"; + case AArch64ISD::LASTB: return "AArch64ISD::LASTB"; + case AArch64ISD::REV: return "AArch64ISD::REV"; + case AArch64ISD::TBL: return "AArch64ISD::TBL"; case AArch64ISD::NOT: return "AArch64ISD::NOT"; case AArch64ISD::BIT: return "AArch64ISD::BIT"; case AArch64ISD::CBZ: return "AArch64ISD::CBZ"; @@ -1311,6 +1356,32 @@ const char *AArch64TargetLowering::getTargetNodeName(unsigned Opcode) const { case AArch64ISD::SUNPKLO: return "AArch64ISD::SUNPKLO"; case AArch64ISD::UUNPKHI: return "AArch64ISD::UUNPKHI"; case AArch64ISD::UUNPKLO: return "AArch64ISD::UUNPKLO"; + case AArch64ISD::INSR: return "AArch64ISD::INSR"; + case AArch64ISD::PTEST: return "AArch64ISD::PTEST"; + case AArch64ISD::PTRUE: return "AArch64ISD::PTRUE"; + case AArch64ISD::GLD1: return "AArch64ISD::GLD1"; + case AArch64ISD::GLD1_SCALED: return "AArch64ISD::GLD1_SCALED"; + case AArch64ISD::GLD1_SXTW: return "AArch64ISD::GLD1_SXTW"; + case AArch64ISD::GLD1_UXTW: return "AArch64ISD::GLD1_UXTW"; + case AArch64ISD::GLD1_SXTW_SCALED: return "AArch64ISD::GLD1_SXTW_SCALED"; + case AArch64ISD::GLD1_UXTW_SCALED: return "AArch64ISD::GLD1_UXTW_SCALED"; + case AArch64ISD::GLD1_IMM: return "AArch64ISD::GLD1_IMM"; + case AArch64ISD::GLD1S: return "AArch64ISD::GLD1S"; + case AArch64ISD::GLD1S_SCALED: return "AArch64ISD::GLD1S_SCALED"; + case AArch64ISD::GLD1S_SXTW: return "AArch64ISD::GLD1S_SXTW"; + case AArch64ISD::GLD1S_UXTW: return "AArch64ISD::GLD1S_UXTW"; + case AArch64ISD::GLD1S_SXTW_SCALED: return "AArch64ISD::GLD1S_SXTW_SCALED"; + case AArch64ISD::GLD1S_UXTW_SCALED: return "AArch64ISD::GLD1S_UXTW_SCALED"; + case AArch64ISD::GLD1S_IMM: return "AArch64ISD::GLD1S_IMM"; + case AArch64ISD::SST1: return "AArch64ISD::SST1"; + case AArch64ISD::SST1_SCALED: return "AArch64ISD::SST1_SCALED"; + case AArch64ISD::SST1_SXTW: return "AArch64ISD::SST1_SXTW"; + case AArch64ISD::SST1_UXTW: return "AArch64ISD::SST1_UXTW"; + case AArch64ISD::SST1_SXTW_SCALED: return "AArch64ISD::SST1_SXTW_SCALED"; + case AArch64ISD::SST1_UXTW_SCALED: return "AArch64ISD::SST1_UXTW_SCALED"; + case AArch64ISD::SST1_IMM: return "AArch64ISD::SST1_IMM"; + case AArch64ISD::LDP: return "AArch64ISD::LDP"; + case AArch64ISD::STP: return "AArch64ISD::STP"; } return nullptr; } @@ -1568,7 +1639,8 @@ static void changeVectorFPCCToAArch64CC(ISD::CondCode CC, // All of the compare-mask comparisons are ordered, but we can switch // between the two by a double inversion. E.g. ULE == !OGT. Invert = true; - changeFPCCToAArch64CC(getSetCCInverse(CC, false), CondCode, CondCode2); + changeFPCCToAArch64CC(getSetCCInverse(CC, /* FP inverse */ MVT::f32), + CondCode, CondCode2); break; } } @@ -1815,7 +1887,7 @@ static SDValue emitConjunctionRec(SelectionDAG &DAG, SDValue Val, ISD::CondCode CC = cast<CondCodeSDNode>(Val->getOperand(2))->get(); bool isInteger = LHS.getValueType().isInteger(); if (Negate) - CC = getSetCCInverse(CC, isInteger); + CC = getSetCCInverse(CC, LHS.getValueType()); SDLoc DL(Val); // Determine OutCC and handle FP special case. if (isInteger) { @@ -2287,7 +2359,7 @@ static SDValue LowerXOR(SDValue Op, SelectionDAG &DAG) { if (CTVal->isAllOnesValue() && CFVal->isNullValue()) { std::swap(TVal, FVal); std::swap(CTVal, CFVal); - CC = ISD::getSetCCInverse(CC, true); + CC = ISD::getSetCCInverse(CC, LHS.getValueType()); } // If the constants line up, perform the transform! @@ -2861,6 +2933,55 @@ SDValue AArch64TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op, case Intrinsic::aarch64_sve_uunpklo: return DAG.getNode(AArch64ISD::UUNPKLO, dl, Op.getValueType(), Op.getOperand(1)); + case Intrinsic::aarch64_sve_clasta_n: + return DAG.getNode(AArch64ISD::CLASTA_N, dl, Op.getValueType(), + Op.getOperand(1), Op.getOperand(2), Op.getOperand(3)); + case Intrinsic::aarch64_sve_clastb_n: + return DAG.getNode(AArch64ISD::CLASTB_N, dl, Op.getValueType(), + Op.getOperand(1), Op.getOperand(2), Op.getOperand(3)); + case Intrinsic::aarch64_sve_lasta: + return DAG.getNode(AArch64ISD::LASTA, dl, Op.getValueType(), + Op.getOperand(1), Op.getOperand(2)); + case Intrinsic::aarch64_sve_lastb: + return DAG.getNode(AArch64ISD::LASTB, dl, Op.getValueType(), + Op.getOperand(1), Op.getOperand(2)); + case Intrinsic::aarch64_sve_rev: + return DAG.getNode(AArch64ISD::REV, dl, Op.getValueType(), + Op.getOperand(1)); + case Intrinsic::aarch64_sve_tbl: + return DAG.getNode(AArch64ISD::TBL, dl, Op.getValueType(), + Op.getOperand(1), Op.getOperand(2)); + case Intrinsic::aarch64_sve_trn1: + return DAG.getNode(AArch64ISD::TRN1, dl, Op.getValueType(), + Op.getOperand(1), Op.getOperand(2)); + case Intrinsic::aarch64_sve_trn2: + return DAG.getNode(AArch64ISD::TRN2, dl, Op.getValueType(), + Op.getOperand(1), Op.getOperand(2)); + case Intrinsic::aarch64_sve_uzp1: + return DAG.getNode(AArch64ISD::UZP1, dl, Op.getValueType(), + Op.getOperand(1), Op.getOperand(2)); + case Intrinsic::aarch64_sve_uzp2: + return DAG.getNode(AArch64ISD::UZP2, dl, Op.getValueType(), + Op.getOperand(1), Op.getOperand(2)); + case Intrinsic::aarch64_sve_zip1: + return DAG.getNode(AArch64ISD::ZIP1, dl, Op.getValueType(), + Op.getOperand(1), Op.getOperand(2)); + case Intrinsic::aarch64_sve_zip2: + return DAG.getNode(AArch64ISD::ZIP2, dl, Op.getValueType(), + Op.getOperand(1), Op.getOperand(2)); + case Intrinsic::aarch64_sve_ptrue: + return DAG.getNode(AArch64ISD::PTRUE, dl, Op.getValueType(), + Op.getOperand(1)); + + case Intrinsic::aarch64_sve_insr: { + SDValue Scalar = Op.getOperand(2); + EVT ScalarTy = Scalar.getValueType(); + if ((ScalarTy == MVT::i8) || (ScalarTy == MVT::i16)) + Scalar = DAG.getNode(ISD::ANY_EXTEND, dl, MVT::i32, Scalar); + + return DAG.getNode(AArch64ISD::INSR, dl, Op.getValueType(), + Op.getOperand(1), Scalar); + } case Intrinsic::localaddress: { const auto &MF = DAG.getMachineFunction(); @@ -2886,6 +3007,10 @@ SDValue AArch64TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op, } } +bool AArch64TargetLowering::isVectorLoadExtDesirable(SDValue ExtVal) const { + return ExtVal.getValueType().isScalableVector(); +} + // Custom lower trunc store for v4i8 vectors, since it is promoted to v4i16. static SDValue LowerTruncateVectorStore(SDLoc DL, StoreSDNode *ST, EVT VT, EVT MemVT, @@ -2920,7 +3045,7 @@ static SDValue LowerTruncateVectorStore(SDLoc DL, StoreSDNode *ST, // Custom lowering for any store, vector or scalar and/or default or with // a truncate operations. Currently only custom lower truncate operation -// from vector v4i16 to v4i8. +// from vector v4i16 to v4i8 or volatile stores of i128. SDValue AArch64TargetLowering::LowerSTORE(SDValue Op, SelectionDAG &DAG) const { SDLoc Dl(Op); @@ -2932,18 +3057,32 @@ SDValue AArch64TargetLowering::LowerSTORE(SDValue Op, EVT VT = Value.getValueType(); EVT MemVT = StoreNode->getMemoryVT(); - assert (VT.isVector() && "Can only custom lower vector store types"); - - unsigned AS = StoreNode->getAddressSpace(); - unsigned Align = StoreNode->getAlignment(); - if (Align < MemVT.getStoreSize() && - !allowsMisalignedMemoryAccesses( - MemVT, AS, Align, StoreNode->getMemOperand()->getFlags(), nullptr)) { - return scalarizeVectorStore(StoreNode, DAG); - } + if (VT.isVector()) { + unsigned AS = StoreNode->getAddressSpace(); + unsigned Align = StoreNode->getAlignment(); + if (Align < MemVT.getStoreSize() && + !allowsMisalignedMemoryAccesses(MemVT, AS, Align, + StoreNode->getMemOperand()->getFlags(), + nullptr)) { + return scalarizeVectorStore(StoreNode, DAG); + } - if (StoreNode->isTruncatingStore()) { - return LowerTruncateVectorStore(Dl, StoreNode, VT, MemVT, DAG); + if (StoreNode->isTruncatingStore()) { + return LowerTruncateVectorStore(Dl, StoreNode, VT, MemVT, DAG); + } + } else if (MemVT == MVT::i128 && StoreNode->isVolatile()) { + assert(StoreNode->getValue()->getValueType(0) == MVT::i128); + SDValue Lo = + DAG.getNode(ISD::EXTRACT_ELEMENT, Dl, MVT::i64, StoreNode->getValue(), + DAG.getConstant(0, Dl, MVT::i64)); + SDValue Hi = + DAG.getNode(ISD::EXTRACT_ELEMENT, Dl, MVT::i64, StoreNode->getValue(), + DAG.getConstant(1, Dl, MVT::i64)); + SDValue Result = DAG.getMemIntrinsicNode( + AArch64ISD::STP, Dl, DAG.getVTList(MVT::Other), + {StoreNode->getChain(), Lo, Hi, StoreNode->getBasePtr()}, + StoreNode->getMemoryVT(), StoreNode->getMemOperand()); + return Result; } return SDValue(); @@ -3092,6 +3231,9 @@ CCAssignFn *AArch64TargetLowering::CCAssignFnForCall(CallingConv::ID CC, switch (CC) { default: report_fatal_error("Unsupported calling convention."); + case CallingConv::AArch64_SVE_VectorCall: + // Calling SVE functions is currently not yet supported. + report_fatal_error("Unsupported calling convention."); case CallingConv::WebKit_JS: return CC_AArch64_WebKit_JS; case CallingConv::GHC: @@ -3111,8 +3253,10 @@ CCAssignFn *AArch64TargetLowering::CCAssignFnForCall(CallingConv::ID CC, : CC_AArch64_DarwinPCS_VarArg; case CallingConv::Win64: return IsVarArg ? CC_AArch64_Win64_VarArg : CC_AArch64_AAPCS; - case CallingConv::AArch64_VectorCall: - return CC_AArch64_AAPCS; + case CallingConv::CFGuard_Check: + return CC_AArch64_Win64_CFGuard_Check; + case CallingConv::AArch64_VectorCall: + return CC_AArch64_AAPCS; } } @@ -3848,11 +3992,10 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI, } // Walk the register/memloc assignments, inserting copies/loads. - for (unsigned i = 0, realArgIdx = 0, e = ArgLocs.size(); i != e; - ++i, ++realArgIdx) { + for (unsigned i = 0, e = ArgLocs.size(); i != e; ++i) { CCValAssign &VA = ArgLocs[i]; - SDValue Arg = OutVals[realArgIdx]; - ISD::ArgFlagsTy Flags = Outs[realArgIdx].Flags; + SDValue Arg = OutVals[i]; + ISD::ArgFlagsTy Flags = Outs[i].Flags; // Promote the value if needed. switch (VA.getLocInfo()) { @@ -3867,7 +4010,7 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI, Arg = DAG.getNode(ISD::ZERO_EXTEND, DL, VA.getLocVT(), Arg); break; case CCValAssign::AExt: - if (Outs[realArgIdx].ArgVT == MVT::i1) { + if (Outs[i].ArgVT == MVT::i1) { // AAPCS requires i1 to be zero-extended to 8-bits by the caller. Arg = DAG.getNode(ISD::TRUNCATE, DL, MVT::i1, Arg); Arg = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::i8, Arg); @@ -3896,7 +4039,7 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI, } if (VA.isRegLoc()) { - if (realArgIdx == 0 && Flags.isReturned() && !Flags.isSwiftSelf() && + if (i == 0 && Flags.isReturned() && !Flags.isSwiftSelf() && Outs[0].VT == MVT::i64) { assert(VA.getLocVT() == MVT::i64 && "unexpected calling convention register assignment"); @@ -4014,14 +4157,11 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI, // node so that legalize doesn't hack it. if (auto *G = dyn_cast<GlobalAddressSDNode>(Callee)) { auto GV = G->getGlobal(); - if (Subtarget->classifyGlobalFunctionReference(GV, getTargetMachine()) == - AArch64II::MO_GOT) { - Callee = DAG.getTargetGlobalAddress(GV, DL, PtrVT, 0, AArch64II::MO_GOT); + unsigned OpFlags = + Subtarget->classifyGlobalFunctionReference(GV, getTargetMachine()); + if (OpFlags & AArch64II::MO_GOT) { + Callee = DAG.getTargetGlobalAddress(GV, DL, PtrVT, 0, OpFlags); Callee = DAG.getNode(AArch64ISD::LOADgot, DL, PtrVT, Callee); - } else if (Subtarget->isTargetCOFF() && GV->hasDLLImportStorageClass()) { - assert(Subtarget->isTargetWindows() && - "Windows is the only supported COFF target"); - Callee = getGOT(G, DAG, AArch64II::MO_DLLIMPORT); } else { const GlobalValue *GV = G->getGlobal(); Callee = DAG.getTargetGlobalAddress(GV, DL, PtrVT, 0, 0); @@ -4456,6 +4596,97 @@ AArch64TargetLowering::LowerDarwinGlobalTLSAddress(SDValue Op, return DAG.getCopyFromReg(Chain, DL, AArch64::X0, PtrVT, Chain.getValue(1)); } +/// Convert a thread-local variable reference into a sequence of instructions to +/// compute the variable's address for the local exec TLS model of ELF targets. +/// The sequence depends on the maximum TLS area size. +SDValue AArch64TargetLowering::LowerELFTLSLocalExec(const GlobalValue *GV, + SDValue ThreadBase, + const SDLoc &DL, + SelectionDAG &DAG) const { + EVT PtrVT = getPointerTy(DAG.getDataLayout()); + SDValue TPOff, Addr; + + switch (DAG.getTarget().Options.TLSSize) { + default: + llvm_unreachable("Unexpected TLS size"); + + case 12: { + // mrs x0, TPIDR_EL0 + // add x0, x0, :tprel_lo12:a + SDValue Var = DAG.getTargetGlobalAddress( + GV, DL, PtrVT, 0, AArch64II::MO_TLS | AArch64II::MO_PAGEOFF); + return SDValue(DAG.getMachineNode(AArch64::ADDXri, DL, PtrVT, ThreadBase, + Var, + DAG.getTargetConstant(0, DL, MVT::i32)), + 0); + } + + case 24: { + // mrs x0, TPIDR_EL0 + // add x0, x0, :tprel_hi12:a + // add x0, x0, :tprel_lo12_nc:a + SDValue HiVar = DAG.getTargetGlobalAddress( + GV, DL, PtrVT, 0, AArch64II::MO_TLS | AArch64II::MO_HI12); + SDValue LoVar = DAG.getTargetGlobalAddress( + GV, DL, PtrVT, 0, + AArch64II::MO_TLS | AArch64II::MO_PAGEOFF | AArch64II::MO_NC); + Addr = SDValue(DAG.getMachineNode(AArch64::ADDXri, DL, PtrVT, ThreadBase, + HiVar, + DAG.getTargetConstant(0, DL, MVT::i32)), + 0); + return SDValue(DAG.getMachineNode(AArch64::ADDXri, DL, PtrVT, Addr, + LoVar, + DAG.getTargetConstant(0, DL, MVT::i32)), + 0); + } + + case 32: { + // mrs x1, TPIDR_EL0 + // movz x0, #:tprel_g1:a + // movk x0, #:tprel_g0_nc:a + // add x0, x1, x0 + SDValue HiVar = DAG.getTargetGlobalAddress( + GV, DL, PtrVT, 0, AArch64II::MO_TLS | AArch64II::MO_G1); + SDValue LoVar = DAG.getTargetGlobalAddress( + GV, DL, PtrVT, 0, + AArch64II::MO_TLS | AArch64II::MO_G0 | AArch64II::MO_NC); + TPOff = SDValue(DAG.getMachineNode(AArch64::MOVZXi, DL, PtrVT, HiVar, + DAG.getTargetConstant(16, DL, MVT::i32)), + 0); + TPOff = SDValue(DAG.getMachineNode(AArch64::MOVKXi, DL, PtrVT, TPOff, LoVar, + DAG.getTargetConstant(0, DL, MVT::i32)), + 0); + return DAG.getNode(ISD::ADD, DL, PtrVT, ThreadBase, TPOff); + } + + case 48: { + // mrs x1, TPIDR_EL0 + // movz x0, #:tprel_g2:a + // movk x0, #:tprel_g1_nc:a + // movk x0, #:tprel_g0_nc:a + // add x0, x1, x0 + SDValue HiVar = DAG.getTargetGlobalAddress( + GV, DL, PtrVT, 0, AArch64II::MO_TLS | AArch64II::MO_G2); + SDValue MiVar = DAG.getTargetGlobalAddress( + GV, DL, PtrVT, 0, + AArch64II::MO_TLS | AArch64II::MO_G1 | AArch64II::MO_NC); + SDValue LoVar = DAG.getTargetGlobalAddress( + GV, DL, PtrVT, 0, + AArch64II::MO_TLS | AArch64II::MO_G0 | AArch64II::MO_NC); + TPOff = SDValue(DAG.getMachineNode(AArch64::MOVZXi, DL, PtrVT, HiVar, + DAG.getTargetConstant(32, DL, MVT::i32)), + 0); + TPOff = SDValue(DAG.getMachineNode(AArch64::MOVKXi, DL, PtrVT, TPOff, MiVar, + DAG.getTargetConstant(16, DL, MVT::i32)), + 0); + TPOff = SDValue(DAG.getMachineNode(AArch64::MOVKXi, DL, PtrVT, TPOff, LoVar, + DAG.getTargetConstant(0, DL, MVT::i32)), + 0); + return DAG.getNode(ISD::ADD, DL, PtrVT, ThreadBase, TPOff); + } + } +} + /// When accessing thread-local variables under either the general-dynamic or /// local-dynamic system, we make a "TLS-descriptor" call. The variable will /// have a descriptor, accessible via a PC-relative ADRP, and whose first entry @@ -4493,15 +4724,7 @@ SDValue AArch64TargetLowering::LowerELFGlobalTLSAddress(SDValue Op, SelectionDAG &DAG) const { assert(Subtarget->isTargetELF() && "This function expects an ELF target"); - if (getTargetMachine().getCodeModel() == CodeModel::Large) - report_fatal_error("ELF TLS only supported in small memory model"); - // Different choices can be made for the maximum size of the TLS area for a - // module. For the small address model, the default TLS size is 16MiB and the - // maximum TLS size is 4GiB. - // FIXME: add -mtls-size command line option and make it control the 16MiB - // vs. 4GiB code sequence generation. - // FIXME: add tiny codemodel support. We currently generate the same code as - // small, which may be larger than needed. + const GlobalAddressSDNode *GA = cast<GlobalAddressSDNode>(Op); TLSModel::Model Model = getTargetMachine().getTLSModel(GA->getGlobal()); @@ -4511,6 +4734,17 @@ AArch64TargetLowering::LowerELFGlobalTLSAddress(SDValue Op, Model = TLSModel::GeneralDynamic; } + if (getTargetMachine().getCodeModel() == CodeModel::Large && + Model != TLSModel::LocalExec) + report_fatal_error("ELF TLS only supported in small memory model or " + "in local exec TLS model"); + // Different choices can be made for the maximum size of the TLS area for a + // module. For the small address model, the default TLS size is 16MiB and the + // maximum TLS size is 4GiB. + // FIXME: add tiny and large code model support for TLS access models other + // than local exec. We currently generate the same code as small for tiny, + // which may be larger than needed. + SDValue TPOff; EVT PtrVT = getPointerTy(DAG.getDataLayout()); SDLoc DL(Op); @@ -4519,23 +4753,7 @@ AArch64TargetLowering::LowerELFGlobalTLSAddress(SDValue Op, SDValue ThreadBase = DAG.getNode(AArch64ISD::THREAD_POINTER, DL, PtrVT); if (Model == TLSModel::LocalExec) { - SDValue HiVar = DAG.getTargetGlobalAddress( - GV, DL, PtrVT, 0, AArch64II::MO_TLS | AArch64II::MO_HI12); - SDValue LoVar = DAG.getTargetGlobalAddress( - GV, DL, PtrVT, 0, - AArch64II::MO_TLS | AArch64II::MO_PAGEOFF | AArch64II::MO_NC); - - SDValue TPWithOff_lo = - SDValue(DAG.getMachineNode(AArch64::ADDXri, DL, PtrVT, ThreadBase, - HiVar, - DAG.getTargetConstant(0, DL, MVT::i32)), - 0); - SDValue TPWithOff = - SDValue(DAG.getMachineNode(AArch64::ADDXri, DL, PtrVT, TPWithOff_lo, - LoVar, - DAG.getTargetConstant(0, DL, MVT::i32)), - 0); - return TPWithOff; + return LowerELFTLSLocalExec(GV, ThreadBase, DL, DAG); } else if (Model == TLSModel::InitialExec) { TPOff = DAG.getTargetGlobalAddress(GV, DL, PtrVT, 0, AArch64II::MO_TLS); TPOff = DAG.getNode(AArch64ISD::LOADgot, DL, PtrVT, TPOff); @@ -4961,8 +5179,8 @@ SDValue AArch64TargetLowering::LowerSETCC(SDValue Op, SelectionDAG &DAG) const { if (LHS.getValueType().isInteger()) { SDValue CCVal; - SDValue Cmp = - getAArch64Cmp(LHS, RHS, ISD::getSetCCInverse(CC, true), CCVal, DAG, dl); + SDValue Cmp = getAArch64Cmp( + LHS, RHS, ISD::getSetCCInverse(CC, LHS.getValueType()), CCVal, DAG, dl); // Note that we inverted the condition above, so we reverse the order of // the true and false operands here. This will allow the setcc to be @@ -4981,7 +5199,8 @@ SDValue AArch64TargetLowering::LowerSETCC(SDValue Op, SelectionDAG &DAG) const { AArch64CC::CondCode CC1, CC2; changeFPCCToAArch64CC(CC, CC1, CC2); if (CC2 == AArch64CC::AL) { - changeFPCCToAArch64CC(ISD::getSetCCInverse(CC, false), CC1, CC2); + changeFPCCToAArch64CC(ISD::getSetCCInverse(CC, LHS.getValueType()), CC1, + CC2); SDValue CC1Val = DAG.getConstant(CC1, dl, MVT::i32); // Note that we inverted the condition above, so we reverse the order of @@ -5042,18 +5261,18 @@ SDValue AArch64TargetLowering::LowerSELECT_CC(ISD::CondCode CC, SDValue LHS, if (CTVal && CFVal && CTVal->isAllOnesValue() && CFVal->isNullValue()) { std::swap(TVal, FVal); std::swap(CTVal, CFVal); - CC = ISD::getSetCCInverse(CC, true); + CC = ISD::getSetCCInverse(CC, LHS.getValueType()); } else if (CTVal && CFVal && CTVal->isOne() && CFVal->isNullValue()) { std::swap(TVal, FVal); std::swap(CTVal, CFVal); - CC = ISD::getSetCCInverse(CC, true); + CC = ISD::getSetCCInverse(CC, LHS.getValueType()); } else if (TVal.getOpcode() == ISD::XOR) { // If TVal is a NOT we want to swap TVal and FVal so that we can match // with a CSINV rather than a CSEL. if (isAllOnesConstant(TVal.getOperand(1))) { std::swap(TVal, FVal); std::swap(CTVal, CFVal); - CC = ISD::getSetCCInverse(CC, true); + CC = ISD::getSetCCInverse(CC, LHS.getValueType()); } } else if (TVal.getOpcode() == ISD::SUB) { // If TVal is a negation (SUB from 0) we want to swap TVal and FVal so @@ -5061,7 +5280,7 @@ SDValue AArch64TargetLowering::LowerSELECT_CC(ISD::CondCode CC, SDValue LHS, if (isNullConstant(TVal.getOperand(0))) { std::swap(TVal, FVal); std::swap(CTVal, CFVal); - CC = ISD::getSetCCInverse(CC, true); + CC = ISD::getSetCCInverse(CC, LHS.getValueType()); } } else if (CTVal && CFVal) { const int64_t TrueVal = CTVal->getSExtValue(); @@ -5104,7 +5323,7 @@ SDValue AArch64TargetLowering::LowerSELECT_CC(ISD::CondCode CC, SDValue LHS, if (Swap) { std::swap(TVal, FVal); std::swap(CTVal, CFVal); - CC = ISD::getSetCCInverse(CC, true); + CC = ISD::getSetCCInverse(CC, LHS.getValueType()); } if (Opcode != AArch64ISD::CSEL) { @@ -5531,7 +5750,7 @@ SDValue AArch64TargetLowering::LowerSPONENTRY(SDValue Op, // FIXME? Maybe this could be a TableGen attribute on some registers and // this table could be generated automatically from RegInfo. Register AArch64TargetLowering:: -getRegisterByName(const char* RegName, EVT VT, const MachineFunction &MF) const { +getRegisterByName(const char* RegName, LLT VT, const MachineFunction &MF) const { Register Reg = MatchRegisterName(RegName); if (AArch64::X1 <= Reg && Reg <= AArch64::X28) { const MCRegisterInfo *MRI = Subtarget->getRegisterInfo(); @@ -6946,19 +7165,55 @@ SDValue AArch64TargetLowering::LowerVECTOR_SHUFFLE(SDValue Op, // Otherwise, duplicate from the lane of the input vector. unsigned Opcode = getDUPLANEOp(V1.getValueType().getVectorElementType()); - // SelectionDAGBuilder may have "helpfully" already extracted or conatenated - // to make a vector of the same size as this SHUFFLE. We can ignore the - // extract entirely, and canonicalise the concat using WidenVector. - if (V1.getOpcode() == ISD::EXTRACT_SUBVECTOR) { - Lane += cast<ConstantSDNode>(V1.getOperand(1))->getZExtValue(); + // Try to eliminate a bitcasted extract subvector before a DUPLANE. + auto getScaledOffsetDup = [](SDValue BitCast, int &LaneC, MVT &CastVT) { + // Match: dup (bitcast (extract_subv X, C)), LaneC + if (BitCast.getOpcode() != ISD::BITCAST || + BitCast.getOperand(0).getOpcode() != ISD::EXTRACT_SUBVECTOR) + return false; + + // The extract index must align in the destination type. That may not + // happen if the bitcast is from narrow to wide type. + SDValue Extract = BitCast.getOperand(0); + unsigned ExtIdx = Extract.getConstantOperandVal(1); + unsigned SrcEltBitWidth = Extract.getScalarValueSizeInBits(); + unsigned ExtIdxInBits = ExtIdx * SrcEltBitWidth; + unsigned CastedEltBitWidth = BitCast.getScalarValueSizeInBits(); + if (ExtIdxInBits % CastedEltBitWidth != 0) + return false; + + // Update the lane value by offsetting with the scaled extract index. + LaneC += ExtIdxInBits / CastedEltBitWidth; + + // Determine the casted vector type of the wide vector input. + // dup (bitcast (extract_subv X, C)), LaneC --> dup (bitcast X), LaneC' + // Examples: + // dup (bitcast (extract_subv v2f64 X, 1) to v2f32), 1 --> dup v4f32 X, 3 + // dup (bitcast (extract_subv v16i8 X, 8) to v4i16), 1 --> dup v8i16 X, 5 + unsigned SrcVecNumElts = + Extract.getOperand(0).getValueSizeInBits() / CastedEltBitWidth; + CastVT = MVT::getVectorVT(BitCast.getSimpleValueType().getScalarType(), + SrcVecNumElts); + return true; + }; + MVT CastVT; + if (getScaledOffsetDup(V1, Lane, CastVT)) { + V1 = DAG.getBitcast(CastVT, V1.getOperand(0).getOperand(0)); + } else if (V1.getOpcode() == ISD::EXTRACT_SUBVECTOR) { + // The lane is incremented by the index of the extract. + // Example: dup v2f32 (extract v4f32 X, 2), 1 --> dup v4f32 X, 3 + Lane += V1.getConstantOperandVal(1); V1 = V1.getOperand(0); } else if (V1.getOpcode() == ISD::CONCAT_VECTORS) { + // The lane is decremented if we are splatting from the 2nd operand. + // Example: dup v4i32 (concat v2i32 X, v2i32 Y), 3 --> dup v4i32 Y, 1 unsigned Idx = Lane >= (int)VT.getVectorNumElements() / 2; Lane -= Idx * VT.getVectorNumElements() / 2; V1 = WidenVector(V1.getOperand(Idx), DAG); - } else if (VT.getSizeInBits() == 64) + } else if (VT.getSizeInBits() == 64) { + // Widen the operand to 128-bit register with undef. V1 = WidenVector(V1, DAG); - + } return DAG.getNode(Opcode, dl, VT, V1, DAG.getConstant(Lane, dl, MVT::i64)); } @@ -7077,26 +7332,31 @@ SDValue AArch64TargetLowering::LowerSPLAT_VECTOR(SDValue Op, switch (ElemVT.getSimpleVT().SimpleTy) { case MVT::i8: case MVT::i16: + case MVT::i32: SplatVal = DAG.getAnyExtOrTrunc(SplatVal, dl, MVT::i32); - break; + return DAG.getNode(AArch64ISD::DUP, dl, VT, SplatVal); case MVT::i64: SplatVal = DAG.getAnyExtOrTrunc(SplatVal, dl, MVT::i64); - break; - case MVT::i32: - // Fine as is - break; - // TODO: we can support splats of i1s and float types, but haven't added - // patterns yet. - case MVT::i1: + return DAG.getNode(AArch64ISD::DUP, dl, VT, SplatVal); + case MVT::i1: { + // The general case of i1. There isn't any natural way to do this, + // so we use some trickery with whilelo. + // TODO: Add special cases for splat of constant true/false. + SplatVal = DAG.getAnyExtOrTrunc(SplatVal, dl, MVT::i64); + SplatVal = DAG.getNode(ISD::SIGN_EXTEND_INREG, dl, MVT::i64, SplatVal, + DAG.getValueType(MVT::i1)); + SDValue ID = DAG.getTargetConstant(Intrinsic::aarch64_sve_whilelo, dl, + MVT::i64); + return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl, VT, ID, + DAG.getConstant(0, dl, MVT::i64), SplatVal); + } + // TODO: we can support float types, but haven't added patterns yet. case MVT::f16: case MVT::f32: case MVT::f64: default: - llvm_unreachable("Unsupported SPLAT_VECTOR input operand type"); - break; + report_fatal_error("Unsupported SPLAT_VECTOR input operand type"); } - - return DAG.getNode(AArch64ISD::DUP, dl, VT, SplatVal); } static bool resolveBuildVector(BuildVectorSDNode *BVN, APInt &CnstBits, @@ -8443,6 +8703,26 @@ bool AArch64TargetLowering::getTgtMemIntrinsic(IntrinsicInfo &Info, Info.align = Align(16); Info.flags = MachineMemOperand::MOStore | MachineMemOperand::MOVolatile; return true; + case Intrinsic::aarch64_sve_ldnt1: { + PointerType *PtrTy = cast<PointerType>(I.getArgOperand(1)->getType()); + Info.opc = ISD::INTRINSIC_W_CHAIN; + Info.memVT = MVT::getVT(PtrTy->getElementType()); + Info.ptrVal = I.getArgOperand(1); + Info.offset = 0; + Info.align = MaybeAlign(DL.getABITypeAlignment(PtrTy->getElementType())); + Info.flags = MachineMemOperand::MOLoad | MachineMemOperand::MONonTemporal; + return true; + } + case Intrinsic::aarch64_sve_stnt1: { + PointerType *PtrTy = cast<PointerType>(I.getArgOperand(2)->getType()); + Info.opc = ISD::INTRINSIC_W_CHAIN; + Info.memVT = MVT::getVT(PtrTy->getElementType()); + Info.ptrVal = I.getArgOperand(2); + Info.offset = 0; + Info.align = MaybeAlign(DL.getABITypeAlignment(PtrTy->getElementType())); + Info.flags = MachineMemOperand::MOStore | MachineMemOperand::MONonTemporal; + return true; + } default: break; } @@ -8515,11 +8795,12 @@ bool AArch64TargetLowering::isProfitableToHoist(Instruction *I) const { return true; const TargetOptions &Options = getTargetMachine().Options; - const DataLayout &DL = I->getModule()->getDataLayout(); - EVT VT = getValueType(DL, User->getOperand(0)->getType()); + const Function *F = I->getFunction(); + const DataLayout &DL = F->getParent()->getDataLayout(); + Type *Ty = User->getOperand(0)->getType(); - return !(isFMAFasterThanFMulAndFAdd(VT) && - isOperationLegalOrCustom(ISD::FMA, VT) && + return !(isFMAFasterThanFMulAndFAdd(*F, Ty) && + isOperationLegalOrCustom(ISD::FMA, getValueType(DL, Ty)) && (Options.AllowFPOpFusion == FPOpFusion::Fast || Options.UnsafeFPMath)); } @@ -9176,7 +9457,8 @@ int AArch64TargetLowering::getScalingFactorCost(const DataLayout &DL, return -1; } -bool AArch64TargetLowering::isFMAFasterThanFMulAndFAdd(EVT VT) const { +bool AArch64TargetLowering::isFMAFasterThanFMulAndFAdd( + const MachineFunction &MF, EVT VT) const { VT = VT.getScalarType(); if (!VT.isSimple()) @@ -9193,6 +9475,17 @@ bool AArch64TargetLowering::isFMAFasterThanFMulAndFAdd(EVT VT) const { return false; } +bool AArch64TargetLowering::isFMAFasterThanFMulAndFAdd(const Function &F, + Type *Ty) const { + switch (Ty->getScalarType()->getTypeID()) { + case Type::FloatTyID: + case Type::DoubleTyID: + return true; + default: + return false; + } +} + const MCPhysReg * AArch64TargetLowering::getScratchRegisters(CallingConv::ID) const { // LR is a callee-save register, but we must treat it as clobbered by any call @@ -9363,6 +9656,19 @@ AArch64TargetLowering::BuildSDIVPow2(SDNode *N, const APInt &Divisor, return DAG.getNode(ISD::SUB, DL, VT, DAG.getConstant(0, DL, VT), SRA); } +static bool IsSVECntIntrinsic(SDValue S) { + switch(getIntrinsicID(S.getNode())) { + default: + break; + case Intrinsic::aarch64_sve_cntb: + case Intrinsic::aarch64_sve_cnth: + case Intrinsic::aarch64_sve_cntw: + case Intrinsic::aarch64_sve_cntd: + return true; + } + return false; +} + static SDValue performMulCombine(SDNode *N, SelectionDAG &DAG, TargetLowering::DAGCombinerInfo &DCI, const AArch64Subtarget *Subtarget) { @@ -9373,9 +9679,18 @@ static SDValue performMulCombine(SDNode *N, SelectionDAG &DAG, if (!isa<ConstantSDNode>(N->getOperand(1))) return SDValue(); + SDValue N0 = N->getOperand(0); ConstantSDNode *C = cast<ConstantSDNode>(N->getOperand(1)); const APInt &ConstValue = C->getAPIntValue(); + // Allow the scaling to be folded into the `cnt` instruction by preventing + // the scaling to be obscured here. This makes it easier to pattern match. + if (IsSVECntIntrinsic(N0) || + (N0->getOpcode() == ISD::TRUNCATE && + (IsSVECntIntrinsic(N0->getOperand(0))))) + if (ConstValue.sge(1) && ConstValue.sle(16)) + return SDValue(); + // Multiplication of a power of two plus/minus one can be done more // cheaply as as shift+add/sub. For now, this is true unilaterally. If // future CPUs have a cheaper MADD instruction, this may need to be @@ -9386,7 +9701,6 @@ static SDValue performMulCombine(SDNode *N, SelectionDAG &DAG, // e.g. 6=3*2=(2+1)*2. // TODO: consider lowering more cases, e.g. C = 14, -6, -14 or even 45 // which equals to (1+2)*16-(1+2). - SDValue N0 = N->getOperand(0); // TrailingZeroes is used to test if the mul can be lowered to // shift+add+shift. unsigned TrailingZeroes = ConstValue.countTrailingZeros(); @@ -9821,6 +10135,67 @@ static SDValue performORCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI, return SDValue(); } +static bool isConstantSplatVectorMaskForType(SDNode *N, EVT MemVT) { + if (!MemVT.getVectorElementType().isSimple()) + return false; + + uint64_t MaskForTy = 0ull; + switch (MemVT.getVectorElementType().getSimpleVT().SimpleTy) { + case MVT::i8: + MaskForTy = 0xffull; + break; + case MVT::i16: + MaskForTy = 0xffffull; + break; + case MVT::i32: + MaskForTy = 0xffffffffull; + break; + default: + return false; + break; + } + + if (N->getOpcode() == AArch64ISD::DUP || N->getOpcode() == ISD::SPLAT_VECTOR) + if (auto *Op0 = dyn_cast<ConstantSDNode>(N->getOperand(0))) + return Op0->getAPIntValue().getLimitedValue() == MaskForTy; + + return false; +} + +static SDValue performSVEAndCombine(SDNode *N, + TargetLowering::DAGCombinerInfo &DCI) { + if (DCI.isBeforeLegalizeOps()) + return SDValue(); + + SDValue Src = N->getOperand(0); + SDValue Mask = N->getOperand(1); + + if (!Src.hasOneUse()) + return SDValue(); + + // GLD1* instructions perform an implicit zero-extend, which makes them + // perfect candidates for combining. + switch (Src->getOpcode()) { + case AArch64ISD::GLD1: + case AArch64ISD::GLD1_SCALED: + case AArch64ISD::GLD1_SXTW: + case AArch64ISD::GLD1_SXTW_SCALED: + case AArch64ISD::GLD1_UXTW: + case AArch64ISD::GLD1_UXTW_SCALED: + case AArch64ISD::GLD1_IMM: + break; + default: + return SDValue(); + } + + EVT MemVT = cast<VTSDNode>(Src->getOperand(4))->getVT(); + + if (isConstantSplatVectorMaskForType(Mask.getNode(), MemVT)) + return Src; + + return SDValue(); +} + static SDValue performANDCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) { SelectionDAG &DAG = DCI.DAG; @@ -9829,6 +10204,9 @@ static SDValue performANDCombine(SDNode *N, if (!VT.isVector() || !DAG.getTargetLoweringInfo().isTypeLegal(VT)) return SDValue(); + if (VT.isScalableVector()) + return performSVEAndCombine(N, DCI); + BuildVectorSDNode *BVN = dyn_cast<BuildVectorSDNode>(N->getOperand(1).getNode()); if (!BVN) @@ -9889,74 +10267,6 @@ static SDValue performSRLCombine(SDNode *N, return SDValue(); } -static SDValue performBitcastCombine(SDNode *N, - TargetLowering::DAGCombinerInfo &DCI, - SelectionDAG &DAG) { - // Wait 'til after everything is legalized to try this. That way we have - // legal vector types and such. - if (DCI.isBeforeLegalizeOps()) - return SDValue(); - - // Remove extraneous bitcasts around an extract_subvector. - // For example, - // (v4i16 (bitconvert - // (extract_subvector (v2i64 (bitconvert (v8i16 ...)), (i64 1))))) - // becomes - // (extract_subvector ((v8i16 ...), (i64 4))) - - // Only interested in 64-bit vectors as the ultimate result. - EVT VT = N->getValueType(0); - if (!VT.isVector()) - return SDValue(); - if (VT.getSimpleVT().getSizeInBits() != 64) - return SDValue(); - // Is the operand an extract_subvector starting at the beginning or halfway - // point of the vector? A low half may also come through as an - // EXTRACT_SUBREG, so look for that, too. - SDValue Op0 = N->getOperand(0); - if (Op0->getOpcode() != ISD::EXTRACT_SUBVECTOR && - !(Op0->isMachineOpcode() && - Op0->getMachineOpcode() == AArch64::EXTRACT_SUBREG)) - return SDValue(); - uint64_t idx = cast<ConstantSDNode>(Op0->getOperand(1))->getZExtValue(); - if (Op0->getOpcode() == ISD::EXTRACT_SUBVECTOR) { - if (Op0->getValueType(0).getVectorNumElements() != idx && idx != 0) - return SDValue(); - } else if (Op0->getMachineOpcode() == AArch64::EXTRACT_SUBREG) { - if (idx != AArch64::dsub) - return SDValue(); - // The dsub reference is equivalent to a lane zero subvector reference. - idx = 0; - } - // Look through the bitcast of the input to the extract. - if (Op0->getOperand(0)->getOpcode() != ISD::BITCAST) - return SDValue(); - SDValue Source = Op0->getOperand(0)->getOperand(0); - // If the source type has twice the number of elements as our destination - // type, we know this is an extract of the high or low half of the vector. - EVT SVT = Source->getValueType(0); - if (!SVT.isVector() || - SVT.getVectorNumElements() != VT.getVectorNumElements() * 2) - return SDValue(); - - LLVM_DEBUG( - dbgs() << "aarch64-lower: bitcast extract_subvector simplification\n"); - - // Create the simplified form to just extract the low or high half of the - // vector directly rather than bothering with the bitcasts. - SDLoc dl(N); - unsigned NumElements = VT.getVectorNumElements(); - if (idx) { - SDValue HalfIdx = DAG.getConstant(NumElements, dl, MVT::i64); - return DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, VT, Source, HalfIdx); - } else { - SDValue SubReg = DAG.getTargetConstant(AArch64::dsub, dl, MVT::i32); - return SDValue(DAG.getMachineNode(TargetOpcode::EXTRACT_SUBREG, dl, VT, - Source, SubReg), - 0); - } -} - static SDValue performConcatVectorsCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI, SelectionDAG &DAG) { @@ -10263,10 +10573,10 @@ static SDValue performSetccAddFolding(SDNode *Op, SelectionDAG &DAG) { MVT::i32); Cmp = *InfoAndKind.Info.AArch64.Cmp; } else - Cmp = getAArch64Cmp(*InfoAndKind.Info.Generic.Opnd0, - *InfoAndKind.Info.Generic.Opnd1, - ISD::getSetCCInverse(InfoAndKind.Info.Generic.CC, true), - CCVal, DAG, dl); + Cmp = getAArch64Cmp( + *InfoAndKind.Info.Generic.Opnd0, *InfoAndKind.Info.Generic.Opnd1, + ISD::getSetCCInverse(InfoAndKind.Info.Generic.CC, CmpVT), CCVal, DAG, + dl); EVT VT = Op->getValueType(0); LHS = DAG.getNode(ISD::ADD, dl, VT, RHS, DAG.getConstant(1, dl, VT)); @@ -10456,6 +10766,154 @@ static SDValue combineAcrossLanesIntrinsic(unsigned Opc, SDNode *N, DAG.getConstant(0, dl, MVT::i64)); } +static SDValue LowerSVEIntReduction(SDNode *N, unsigned Opc, + SelectionDAG &DAG) { + SDLoc dl(N); + LLVMContext &Ctx = *DAG.getContext(); + const TargetLowering &TLI = DAG.getTargetLoweringInfo(); + + EVT VT = N->getValueType(0); + SDValue Pred = N->getOperand(1); + SDValue Data = N->getOperand(2); + EVT DataVT = Data.getValueType(); + + if (DataVT.getVectorElementType().isScalarInteger() && + (VT == MVT::i8 || VT == MVT::i16 || VT == MVT::i32 || VT == MVT::i64)) { + if (!TLI.isTypeLegal(DataVT)) + return SDValue(); + + EVT OutputVT = EVT::getVectorVT(Ctx, VT, + AArch64::NeonBitsPerVector / VT.getSizeInBits()); + SDValue Reduce = DAG.getNode(Opc, dl, OutputVT, Pred, Data); + SDValue Zero = DAG.getConstant(0, dl, MVT::i64); + SDValue Result = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, VT, Reduce, Zero); + + return Result; + } + + return SDValue(); +} + +static SDValue LowerSVEIntrinsicEXT(SDNode *N, SelectionDAG &DAG) { + SDLoc dl(N); + LLVMContext &Ctx = *DAG.getContext(); + EVT VT = N->getValueType(0); + + assert(VT.isScalableVector() && "Expected a scalable vector."); + + // Current lowering only supports the SVE-ACLE types. + if (VT.getSizeInBits().getKnownMinSize() != AArch64::SVEBitsPerBlock) + return SDValue(); + + unsigned ElemSize = VT.getVectorElementType().getSizeInBits() / 8; + unsigned ByteSize = VT.getSizeInBits().getKnownMinSize() / 8; + EVT ByteVT = EVT::getVectorVT(Ctx, MVT::i8, { ByteSize, true }); + + // Convert everything to the domain of EXT (i.e bytes). + SDValue Op0 = DAG.getNode(ISD::BITCAST, dl, ByteVT, N->getOperand(1)); + SDValue Op1 = DAG.getNode(ISD::BITCAST, dl, ByteVT, N->getOperand(2)); + SDValue Op2 = DAG.getNode(ISD::MUL, dl, MVT::i32, N->getOperand(3), + DAG.getConstant(ElemSize, dl, MVT::i32)); + + SDValue EXT = DAG.getNode(AArch64ISD::EXT, dl, ByteVT, Op0, Op1, Op2); + return DAG.getNode(ISD::BITCAST, dl, VT, EXT); +} + +static SDValue tryConvertSVEWideCompare(SDNode *N, unsigned ReplacementIID, + bool Invert, + TargetLowering::DAGCombinerInfo &DCI, + SelectionDAG &DAG) { + if (DCI.isBeforeLegalize()) + return SDValue(); + + SDValue Comparator = N->getOperand(3); + if (Comparator.getOpcode() == AArch64ISD::DUP || + Comparator.getOpcode() == ISD::SPLAT_VECTOR) { + unsigned IID = getIntrinsicID(N); + EVT VT = N->getValueType(0); + EVT CmpVT = N->getOperand(2).getValueType(); + SDValue Pred = N->getOperand(1); + SDValue Imm; + SDLoc DL(N); + + switch (IID) { + default: + llvm_unreachable("Called with wrong intrinsic!"); + break; + + // Signed comparisons + case Intrinsic::aarch64_sve_cmpeq_wide: + case Intrinsic::aarch64_sve_cmpne_wide: + case Intrinsic::aarch64_sve_cmpge_wide: + case Intrinsic::aarch64_sve_cmpgt_wide: + case Intrinsic::aarch64_sve_cmplt_wide: + case Intrinsic::aarch64_sve_cmple_wide: { + if (auto *CN = dyn_cast<ConstantSDNode>(Comparator.getOperand(0))) { + int64_t ImmVal = CN->getSExtValue(); + if (ImmVal >= -16 && ImmVal <= 15) + Imm = DAG.getConstant(ImmVal, DL, MVT::i32); + else + return SDValue(); + } + break; + } + // Unsigned comparisons + case Intrinsic::aarch64_sve_cmphs_wide: + case Intrinsic::aarch64_sve_cmphi_wide: + case Intrinsic::aarch64_sve_cmplo_wide: + case Intrinsic::aarch64_sve_cmpls_wide: { + if (auto *CN = dyn_cast<ConstantSDNode>(Comparator.getOperand(0))) { + uint64_t ImmVal = CN->getZExtValue(); + if (ImmVal <= 127) + Imm = DAG.getConstant(ImmVal, DL, MVT::i32); + else + return SDValue(); + } + break; + } + } + + SDValue Splat = DAG.getNode(ISD::SPLAT_VECTOR, DL, CmpVT, Imm); + SDValue ID = DAG.getTargetConstant(ReplacementIID, DL, MVT::i64); + SDValue Op0, Op1; + if (Invert) { + Op0 = Splat; + Op1 = N->getOperand(2); + } else { + Op0 = N->getOperand(2); + Op1 = Splat; + } + return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, VT, + ID, Pred, Op0, Op1); + } + + return SDValue(); +} + +static SDValue getPTest(SelectionDAG &DAG, EVT VT, SDValue Pg, SDValue Op, + AArch64CC::CondCode Cond) { + const TargetLowering &TLI = DAG.getTargetLoweringInfo(); + + SDLoc DL(Op); + assert(Op.getValueType().isScalableVector() && + TLI.isTypeLegal(Op.getValueType()) && + "Expected legal scalable vector type!"); + + // Ensure target specific opcodes are using legal type. + EVT OutVT = TLI.getTypeToTransformTo(*DAG.getContext(), VT); + SDValue TVal = DAG.getConstant(1, DL, OutVT); + SDValue FVal = DAG.getConstant(0, DL, OutVT); + + // Set condition code (CC) flags. + SDValue Test = DAG.getNode(AArch64ISD::PTEST, DL, MVT::Other, Pg, Op); + + // Convert CC to integer based on requested condition. + // NOTE: Cond is inverted to promote CSEL's removal when it feeds a compare. + SDValue CC = DAG.getConstant(getInvertedCondCode(Cond), DL, MVT::i32); + SDValue Res = DAG.getNode(AArch64ISD::CSEL, DL, OutVT, FVal, TVal, CC, Test); + return DAG.getZExtOrTrunc(Res, DL, VT); +} + static SDValue performIntrinsicCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI, const AArch64Subtarget *Subtarget) { @@ -10510,6 +10968,61 @@ static SDValue performIntrinsicCombine(SDNode *N, case Intrinsic::aarch64_crc32h: case Intrinsic::aarch64_crc32ch: return tryCombineCRC32(0xffff, N, DAG); + case Intrinsic::aarch64_sve_smaxv: + return LowerSVEIntReduction(N, AArch64ISD::SMAXV_PRED, DAG); + case Intrinsic::aarch64_sve_umaxv: + return LowerSVEIntReduction(N, AArch64ISD::UMAXV_PRED, DAG); + case Intrinsic::aarch64_sve_sminv: + return LowerSVEIntReduction(N, AArch64ISD::SMINV_PRED, DAG); + case Intrinsic::aarch64_sve_uminv: + return LowerSVEIntReduction(N, AArch64ISD::UMINV_PRED, DAG); + case Intrinsic::aarch64_sve_orv: + return LowerSVEIntReduction(N, AArch64ISD::ORV_PRED, DAG); + case Intrinsic::aarch64_sve_eorv: + return LowerSVEIntReduction(N, AArch64ISD::EORV_PRED, DAG); + case Intrinsic::aarch64_sve_andv: + return LowerSVEIntReduction(N, AArch64ISD::ANDV_PRED, DAG); + case Intrinsic::aarch64_sve_ext: + return LowerSVEIntrinsicEXT(N, DAG); + case Intrinsic::aarch64_sve_cmpeq_wide: + return tryConvertSVEWideCompare(N, Intrinsic::aarch64_sve_cmpeq, + false, DCI, DAG); + case Intrinsic::aarch64_sve_cmpne_wide: + return tryConvertSVEWideCompare(N, Intrinsic::aarch64_sve_cmpne, + false, DCI, DAG); + case Intrinsic::aarch64_sve_cmpge_wide: + return tryConvertSVEWideCompare(N, Intrinsic::aarch64_sve_cmpge, + false, DCI, DAG); + case Intrinsic::aarch64_sve_cmpgt_wide: + return tryConvertSVEWideCompare(N, Intrinsic::aarch64_sve_cmpgt, + false, DCI, DAG); + case Intrinsic::aarch64_sve_cmplt_wide: + return tryConvertSVEWideCompare(N, Intrinsic::aarch64_sve_cmpgt, + true, DCI, DAG); + case Intrinsic::aarch64_sve_cmple_wide: + return tryConvertSVEWideCompare(N, Intrinsic::aarch64_sve_cmpge, + true, DCI, DAG); + case Intrinsic::aarch64_sve_cmphs_wide: + return tryConvertSVEWideCompare(N, Intrinsic::aarch64_sve_cmphs, + false, DCI, DAG); + case Intrinsic::aarch64_sve_cmphi_wide: + return tryConvertSVEWideCompare(N, Intrinsic::aarch64_sve_cmphi, + false, DCI, DAG); + case Intrinsic::aarch64_sve_cmplo_wide: + return tryConvertSVEWideCompare(N, Intrinsic::aarch64_sve_cmphi, true, + DCI, DAG); + case Intrinsic::aarch64_sve_cmpls_wide: + return tryConvertSVEWideCompare(N, Intrinsic::aarch64_sve_cmphs, true, + DCI, DAG); + case Intrinsic::aarch64_sve_ptest_any: + return getPTest(DAG, N->getValueType(0), N->getOperand(1), N->getOperand(2), + AArch64CC::ANY_ACTIVE); + case Intrinsic::aarch64_sve_ptest_first: + return getPTest(DAG, N->getValueType(0), N->getOperand(1), N->getOperand(2), + AArch64CC::FIRST_ACTIVE); + case Intrinsic::aarch64_sve_ptest_last: + return getPTest(DAG, N->getValueType(0), N->getOperand(1), N->getOperand(2), + AArch64CC::LAST_ACTIVE); } return SDValue(); } @@ -10652,6 +11165,48 @@ static SDValue splitStoreSplat(SelectionDAG &DAG, StoreSDNode &St, return NewST1; } +static SDValue performLDNT1Combine(SDNode *N, SelectionDAG &DAG) { + SDLoc DL(N); + EVT VT = N->getValueType(0); + EVT PtrTy = N->getOperand(3).getValueType(); + + EVT LoadVT = VT; + if (VT.isFloatingPoint()) + LoadVT = VT.changeTypeToInteger(); + + auto *MINode = cast<MemIntrinsicSDNode>(N); + SDValue PassThru = DAG.getConstant(0, DL, LoadVT); + SDValue L = DAG.getMaskedLoad(LoadVT, DL, MINode->getChain(), + MINode->getOperand(3), DAG.getUNDEF(PtrTy), + MINode->getOperand(2), PassThru, + MINode->getMemoryVT(), MINode->getMemOperand(), + ISD::UNINDEXED, ISD::NON_EXTLOAD, false); + + if (VT.isFloatingPoint()) { + SDValue Ops[] = { DAG.getNode(ISD::BITCAST, DL, VT, L), L.getValue(1) }; + return DAG.getMergeValues(Ops, DL); + } + + return L; +} + +static SDValue performSTNT1Combine(SDNode *N, SelectionDAG &DAG) { + SDLoc DL(N); + + SDValue Data = N->getOperand(2); + EVT DataVT = Data.getValueType(); + EVT PtrTy = N->getOperand(4).getValueType(); + + if (DataVT.isFloatingPoint()) + Data = DAG.getNode(ISD::BITCAST, DL, DataVT.changeTypeToInteger(), Data); + + auto *MINode = cast<MemIntrinsicSDNode>(N); + return DAG.getMaskedStore(MINode->getChain(), DL, Data, MINode->getOperand(4), + DAG.getUNDEF(PtrTy), MINode->getOperand(3), + MINode->getMemoryVT(), MINode->getMemOperand(), + ISD::UNINDEXED, false, false); +} + /// Replace a splat of zeros to a vector store by scalar stores of WZR/XZR. The /// load store optimizer pass will merge them to store pair stores. This should /// be better than a movi to create the vector zero followed by a vector store @@ -11703,6 +12258,215 @@ static SDValue performGlobalAddressCombine(SDNode *N, SelectionDAG &DAG, DAG.getConstant(MinOffset, DL, MVT::i64)); } +// Returns an SVE type that ContentTy can be trivially sign or zero extended +// into. +static MVT getSVEContainerType(EVT ContentTy) { + assert(ContentTy.isSimple() && "No SVE containers for extended types"); + + switch (ContentTy.getSimpleVT().SimpleTy) { + default: + llvm_unreachable("No known SVE container for this MVT type"); + case MVT::nxv2i8: + case MVT::nxv2i16: + case MVT::nxv2i32: + case MVT::nxv2i64: + case MVT::nxv2f32: + case MVT::nxv2f64: + return MVT::nxv2i64; + case MVT::nxv4i8: + case MVT::nxv4i16: + case MVT::nxv4i32: + case MVT::nxv4f32: + return MVT::nxv4i32; + } +} + +static SDValue performST1ScatterCombine(SDNode *N, SelectionDAG &DAG, + unsigned Opcode, + bool OnlyPackedOffsets = true) { + const SDValue Src = N->getOperand(2); + const EVT SrcVT = Src->getValueType(0); + assert(SrcVT.isScalableVector() && + "Scatter stores are only possible for SVE vectors"); + + SDLoc DL(N); + MVT SrcElVT = SrcVT.getVectorElementType().getSimpleVT(); + + // Make sure that source data will fit into an SVE register + if (SrcVT.getSizeInBits().getKnownMinSize() > AArch64::SVEBitsPerBlock) + return SDValue(); + + // For FPs, ACLE only supports _packed_ single and double precision types. + if (SrcElVT.isFloatingPoint()) + if ((SrcVT != MVT::nxv4f32) && (SrcVT != MVT::nxv2f64)) + return SDValue(); + + // Depending on the addressing mode, this is either a pointer or a vector of + // pointers (that fits into one register) + const SDValue Base = N->getOperand(4); + // Depending on the addressing mode, this is either a single offset or a + // vector of offsets (that fits into one register) + SDValue Offset = N->getOperand(5); + + auto &TLI = DAG.getTargetLoweringInfo(); + if (!TLI.isTypeLegal(Base.getValueType())) + return SDValue(); + + // Some scatter store variants allow unpacked offsets, but only as nxv2i32 + // vectors. These are implicitly sign (sxtw) or zero (zxtw) extend to + // nxv2i64. Legalize accordingly. + if (!OnlyPackedOffsets && + Offset.getValueType().getSimpleVT().SimpleTy == MVT::nxv2i32) + Offset = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::nxv2i64, Offset).getValue(0); + + if (!TLI.isTypeLegal(Offset.getValueType())) + return SDValue(); + + // Source value type that is representable in hardware + EVT HwSrcVt = getSVEContainerType(SrcVT); + + // Keep the original type of the input data to store - this is needed to + // differentiate between ST1B, ST1H, ST1W and ST1D. For FP values we want the + // integer equivalent, so just use HwSrcVt. + SDValue InputVT = DAG.getValueType(SrcVT); + if (SrcVT.isFloatingPoint()) + InputVT = DAG.getValueType(HwSrcVt); + + SDVTList VTs = DAG.getVTList(MVT::Other); + SDValue SrcNew; + + if (Src.getValueType().isFloatingPoint()) + SrcNew = DAG.getNode(ISD::BITCAST, DL, HwSrcVt, Src); + else + SrcNew = DAG.getNode(ISD::ANY_EXTEND, DL, HwSrcVt, Src); + + SDValue Ops[] = {N->getOperand(0), // Chain + SrcNew, + N->getOperand(3), // Pg + Base, + Offset, + InputVT}; + + return DAG.getNode(Opcode, DL, VTs, Ops); +} + +static SDValue performLD1GatherCombine(SDNode *N, SelectionDAG &DAG, + unsigned Opcode, + bool OnlyPackedOffsets = true) { + EVT RetVT = N->getValueType(0); + assert(RetVT.isScalableVector() && + "Gather loads are only possible for SVE vectors"); + SDLoc DL(N); + + if (RetVT.getSizeInBits().getKnownMinSize() > AArch64::SVEBitsPerBlock) + return SDValue(); + + // Depending on the addressing mode, this is either a pointer or a vector of + // pointers (that fits into one register) + const SDValue Base = N->getOperand(3); + // Depending on the addressing mode, this is either a single offset or a + // vector of offsets (that fits into one register) + SDValue Offset = N->getOperand(4); + + auto &TLI = DAG.getTargetLoweringInfo(); + if (!TLI.isTypeLegal(Base.getValueType())) + return SDValue(); + + // Some gather load variants allow unpacked offsets, but only as nxv2i32 + // vectors. These are implicitly sign (sxtw) or zero (zxtw) extend to + // nxv2i64. Legalize accordingly. + if (!OnlyPackedOffsets && + Offset.getValueType().getSimpleVT().SimpleTy == MVT::nxv2i32) + Offset = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::nxv2i64, Offset).getValue(0); + + // Return value type that is representable in hardware + EVT HwRetVt = getSVEContainerType(RetVT); + + // Keep the original output value type around - this will better inform + // optimisations (e.g. instruction folding when load is followed by + // zext/sext). This will only be used for ints, so the value for FPs + // doesn't matter. + SDValue OutVT = DAG.getValueType(RetVT); + if (RetVT.isFloatingPoint()) + OutVT = DAG.getValueType(HwRetVt); + + SDVTList VTs = DAG.getVTList(HwRetVt, MVT::Other); + SDValue Ops[] = {N->getOperand(0), // Chain + N->getOperand(2), // Pg + Base, Offset, OutVT}; + + SDValue Load = DAG.getNode(Opcode, DL, VTs, Ops); + SDValue LoadChain = SDValue(Load.getNode(), 1); + + if (RetVT.isInteger() && (RetVT != HwRetVt)) + Load = DAG.getNode(ISD::TRUNCATE, DL, RetVT, Load.getValue(0)); + + // If the original return value was FP, bitcast accordingly. Doing it here + // means that we can avoid adding TableGen patterns for FPs. + if (RetVT.isFloatingPoint()) + Load = DAG.getNode(ISD::BITCAST, DL, RetVT, Load.getValue(0)); + + return DAG.getMergeValues({Load, LoadChain}, DL); +} + + +static SDValue +performSignExtendInRegCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI, + SelectionDAG &DAG) { + if (DCI.isBeforeLegalizeOps()) + return SDValue(); + + SDValue Src = N->getOperand(0); + unsigned Opc = Src->getOpcode(); + + // Gather load nodes (e.g. AArch64ISD::GLD1) are straightforward candidates + // for DAG Combine with SIGN_EXTEND_INREG. Bail out for all other nodes. + unsigned NewOpc; + switch (Opc) { + case AArch64ISD::GLD1: + NewOpc = AArch64ISD::GLD1S; + break; + case AArch64ISD::GLD1_SCALED: + NewOpc = AArch64ISD::GLD1S_SCALED; + break; + case AArch64ISD::GLD1_SXTW: + NewOpc = AArch64ISD::GLD1S_SXTW; + break; + case AArch64ISD::GLD1_SXTW_SCALED: + NewOpc = AArch64ISD::GLD1S_SXTW_SCALED; + break; + case AArch64ISD::GLD1_UXTW: + NewOpc = AArch64ISD::GLD1S_UXTW; + break; + case AArch64ISD::GLD1_UXTW_SCALED: + NewOpc = AArch64ISD::GLD1S_UXTW_SCALED; + break; + case AArch64ISD::GLD1_IMM: + NewOpc = AArch64ISD::GLD1S_IMM; + break; + default: + return SDValue(); + } + + EVT SignExtSrcVT = cast<VTSDNode>(N->getOperand(1))->getVT(); + EVT GLD1SrcMemVT = cast<VTSDNode>(Src->getOperand(4))->getVT(); + + if ((SignExtSrcVT != GLD1SrcMemVT) || !Src.hasOneUse()) + return SDValue(); + + EVT DstVT = N->getValueType(0); + SDVTList VTs = DAG.getVTList(DstVT, MVT::Other); + SDValue Ops[] = {Src->getOperand(0), Src->getOperand(1), Src->getOperand(2), + Src->getOperand(3), Src->getOperand(4)}; + + SDValue ExtLoad = DAG.getNode(NewOpc, SDLoc(N), VTs, Ops); + DCI.CombineTo(N, ExtLoad); + DCI.CombineTo(Src.getNode(), ExtLoad, ExtLoad.getValue(1)); + + // Return N so it doesn't get rechecked + return SDValue(N, 0); +} + SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N, DAGCombinerInfo &DCI) const { SelectionDAG &DAG = DCI.DAG; @@ -11737,8 +12501,8 @@ SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N, case ISD::ZERO_EXTEND: case ISD::SIGN_EXTEND: return performExtendCombine(N, DCI, DAG); - case ISD::BITCAST: - return performBitcastCombine(N, DCI, DAG); + case ISD::SIGN_EXTEND_INREG: + return performSignExtendInRegCombine(N, DCI, DAG); case ISD::CONCAT_VECTORS: return performConcatVectorsCombine(N, DCI, DAG); case ISD::SELECT: @@ -11789,6 +12553,46 @@ SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N, case Intrinsic::aarch64_neon_st3lane: case Intrinsic::aarch64_neon_st4lane: return performNEONPostLDSTCombine(N, DCI, DAG); + case Intrinsic::aarch64_sve_ldnt1: + return performLDNT1Combine(N, DAG); + case Intrinsic::aarch64_sve_stnt1: + return performSTNT1Combine(N, DAG); + case Intrinsic::aarch64_sve_ld1_gather: + return performLD1GatherCombine(N, DAG, AArch64ISD::GLD1); + case Intrinsic::aarch64_sve_ld1_gather_index: + return performLD1GatherCombine(N, DAG, AArch64ISD::GLD1_SCALED); + case Intrinsic::aarch64_sve_ld1_gather_sxtw: + return performLD1GatherCombine(N, DAG, AArch64ISD::GLD1_SXTW, + /*OnlyPackedOffsets=*/false); + case Intrinsic::aarch64_sve_ld1_gather_uxtw: + return performLD1GatherCombine(N, DAG, AArch64ISD::GLD1_UXTW, + /*OnlyPackedOffsets=*/false); + case Intrinsic::aarch64_sve_ld1_gather_sxtw_index: + return performLD1GatherCombine(N, DAG, AArch64ISD::GLD1_SXTW_SCALED, + /*OnlyPackedOffsets=*/false); + case Intrinsic::aarch64_sve_ld1_gather_uxtw_index: + return performLD1GatherCombine(N, DAG, AArch64ISD::GLD1_UXTW_SCALED, + /*OnlyPackedOffsets=*/false); + case Intrinsic::aarch64_sve_ld1_gather_imm: + return performLD1GatherCombine(N, DAG, AArch64ISD::GLD1_IMM); + case Intrinsic::aarch64_sve_st1_scatter: + return performST1ScatterCombine(N, DAG, AArch64ISD::SST1); + case Intrinsic::aarch64_sve_st1_scatter_index: + return performST1ScatterCombine(N, DAG, AArch64ISD::SST1_SCALED); + case Intrinsic::aarch64_sve_st1_scatter_sxtw: + return performST1ScatterCombine(N, DAG, AArch64ISD::SST1_SXTW, + /*OnlyPackedOffsets=*/false); + case Intrinsic::aarch64_sve_st1_scatter_uxtw: + return performST1ScatterCombine(N, DAG, AArch64ISD::SST1_UXTW, + /*OnlyPackedOffsets=*/false); + case Intrinsic::aarch64_sve_st1_scatter_sxtw_index: + return performST1ScatterCombine(N, DAG, AArch64ISD::SST1_SXTW_SCALED, + /*OnlyPackedOffsets=*/false); + case Intrinsic::aarch64_sve_st1_scatter_uxtw_index: + return performST1ScatterCombine(N, DAG, AArch64ISD::SST1_UXTW_SCALED, + /*OnlyPackedOffsets=*/false); + case Intrinsic::aarch64_sve_st1_scatter_imm: + return performST1ScatterCombine(N, DAG, AArch64ISD::SST1_IMM); default: break; } @@ -12084,6 +12888,69 @@ void AArch64TargetLowering::ReplaceNodeResults( case ISD::ATOMIC_CMP_SWAP: ReplaceCMP_SWAP_128Results(N, Results, DAG, Subtarget); return; + case ISD::LOAD: { + assert(SDValue(N, 0).getValueType() == MVT::i128 && + "unexpected load's value type"); + LoadSDNode *LoadNode = cast<LoadSDNode>(N); + if (!LoadNode->isVolatile() || LoadNode->getMemoryVT() != MVT::i128) { + // Non-volatile loads are optimized later in AArch64's load/store + // optimizer. + return; + } + + SDValue Result = DAG.getMemIntrinsicNode( + AArch64ISD::LDP, SDLoc(N), + DAG.getVTList({MVT::i64, MVT::i64, MVT::Other}), + {LoadNode->getChain(), LoadNode->getBasePtr()}, LoadNode->getMemoryVT(), + LoadNode->getMemOperand()); + + SDValue Pair = DAG.getNode(ISD::BUILD_PAIR, SDLoc(N), MVT::i128, + Result.getValue(0), Result.getValue(1)); + Results.append({Pair, Result.getValue(2) /* Chain */}); + return; + } + case ISD::INTRINSIC_WO_CHAIN: { + EVT VT = N->getValueType(0); + assert((VT == MVT::i8 || VT == MVT::i16) && + "custom lowering for unexpected type"); + + ConstantSDNode *CN = cast<ConstantSDNode>(N->getOperand(0)); + Intrinsic::ID IntID = static_cast<Intrinsic::ID>(CN->getZExtValue()); + switch (IntID) { + default: + return; + case Intrinsic::aarch64_sve_clasta_n: { + SDLoc DL(N); + auto Op2 = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i32, N->getOperand(2)); + auto V = DAG.getNode(AArch64ISD::CLASTA_N, DL, MVT::i32, + N->getOperand(1), Op2, N->getOperand(3)); + Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, VT, V)); + return; + } + case Intrinsic::aarch64_sve_clastb_n: { + SDLoc DL(N); + auto Op2 = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i32, N->getOperand(2)); + auto V = DAG.getNode(AArch64ISD::CLASTB_N, DL, MVT::i32, + N->getOperand(1), Op2, N->getOperand(3)); + Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, VT, V)); + return; + } + case Intrinsic::aarch64_sve_lasta: { + SDLoc DL(N); + auto V = DAG.getNode(AArch64ISD::LASTA, DL, MVT::i32, + N->getOperand(1), N->getOperand(2)); + Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, VT, V)); + return; + } + case Intrinsic::aarch64_sve_lastb: { + SDLoc DL(N); + auto V = DAG.getNode(AArch64ISD::LASTB, DL, MVT::i32, + N->getOperand(1), N->getOperand(2)); + Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, VT, V)); + return; + } + } + } } } @@ -12351,7 +13218,7 @@ bool AArch64TargetLowering:: bool AArch64TargetLowering::shouldExpandShift(SelectionDAG &DAG, SDNode *N) const { if (DAG.getMachineFunction().getFunction().hasMinSize() && - !Subtarget->isTargetWindows()) + !Subtarget->isTargetWindows() && !Subtarget->isTargetDarwin()) return false; return true; } |
