diff options
author | Dimitry Andric <dim@FreeBSD.org> | 2023-02-11 12:38:04 +0000 |
---|---|---|
committer | Dimitry Andric <dim@FreeBSD.org> | 2023-02-11 12:38:11 +0000 |
commit | e3b557809604d036af6e00c60f012c2025b59a5e (patch) | |
tree | 8a11ba2269a3b669601e2fd41145b174008f4da8 /llvm/lib/Target/Hexagon/HexagonISelLowering.cpp | |
parent | 08e8dd7b9db7bb4a9de26d44c1cbfd24e869c014 (diff) | |
download | src-e3b557809604d036af6e00c60f012c2025b59a5e.tar.gz src-e3b557809604d036af6e00c60f012c2025b59a5e.zip |
Diffstat (limited to 'llvm/lib/Target/Hexagon/HexagonISelLowering.cpp')
-rw-r--r-- | llvm/lib/Target/Hexagon/HexagonISelLowering.cpp | 505 |
1 files changed, 357 insertions, 148 deletions
diff --git a/llvm/lib/Target/Hexagon/HexagonISelLowering.cpp b/llvm/lib/Target/Hexagon/HexagonISelLowering.cpp index 94411b2e4f98..202fc473f9e4 100644 --- a/llvm/lib/Target/Hexagon/HexagonISelLowering.cpp +++ b/llvm/lib/Target/Hexagon/HexagonISelLowering.cpp @@ -144,7 +144,7 @@ static bool CC_SkipOdd(unsigned &ValNo, MVT &ValVT, MVT &LocVT, Hexagon::R0, Hexagon::R1, Hexagon::R2, Hexagon::R3, Hexagon::R4, Hexagon::R5 }; - const unsigned NumArgRegs = array_lengthof(ArgRegs); + const unsigned NumArgRegs = std::size(ArgRegs); unsigned RegNum = State.getFirstUnallocated(ArgRegs); // RegNum is an index into ArgRegs: skip a register if RegNum is odd. @@ -612,8 +612,7 @@ HexagonTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI, Glue = Chain.getValue(1); // Create the CALLSEQ_END node. - Chain = DAG.getCALLSEQ_END(Chain, DAG.getIntPtrConstant(NumBytes, dl, true), - DAG.getIntPtrConstant(0, dl, true), Glue, dl); + Chain = DAG.getCALLSEQ_END(Chain, NumBytes, 0, Glue, dl); Glue = Chain.getValue(1); // Handle result values, copying them out of physregs into vregs that we @@ -1809,6 +1808,8 @@ HexagonTargetLowering::HexagonTargetLowering(const TargetMachine &TM, setOperationAction(ISD::FMUL, MVT::f64, Legal); } + setTargetDAGCombine(ISD::OR); + setTargetDAGCombine(ISD::TRUNCATE); setTargetDAGCombine(ISD::VSELECT); if (Subtarget.useHVXOps()) @@ -1900,6 +1901,13 @@ const char* HexagonTargetLowering::getTargetNodeName(unsigned Opcode) const { case HexagonISD::VASL: return "HexagonISD::VASL"; case HexagonISD::VASR: return "HexagonISD::VASR"; case HexagonISD::VLSR: return "HexagonISD::VLSR"; + case HexagonISD::MFSHL: return "HexagonISD::MFSHL"; + case HexagonISD::MFSHR: return "HexagonISD::MFSHR"; + case HexagonISD::SSAT: return "HexagonISD::SSAT"; + case HexagonISD::USAT: return "HexagonISD::USAT"; + case HexagonISD::SMUL_LOHI: return "HexagonISD::SMUL_LOHI"; + case HexagonISD::UMUL_LOHI: return "HexagonISD::UMUL_LOHI"; + case HexagonISD::USMUL_LOHI: return "HexagonISD::USMUL_LOHI"; case HexagonISD::VEXTRACTW: return "HexagonISD::VEXTRACTW"; case HexagonISD::VINSERTW0: return "HexagonISD::VINSERTW0"; case HexagonISD::VROR: return "HexagonISD::VROR"; @@ -1913,12 +1921,11 @@ const char* HexagonTargetLowering::getTargetNodeName(unsigned Opcode) const { case HexagonISD::QCAT: return "HexagonISD::QCAT"; case HexagonISD::QTRUE: return "HexagonISD::QTRUE"; case HexagonISD::QFALSE: return "HexagonISD::QFALSE"; + case HexagonISD::TL_EXTEND: return "HexagonISD::TL_EXTEND"; + case HexagonISD::TL_TRUNCATE: return "HexagonISD::TL_TRUNCATE"; case HexagonISD::TYPECAST: return "HexagonISD::TYPECAST"; case HexagonISD::VALIGN: return "HexagonISD::VALIGN"; case HexagonISD::VALIGNADDR: return "HexagonISD::VALIGNADDR"; - case HexagonISD::VPACKL: return "HexagonISD::VPACKL"; - case HexagonISD::VUNPACK: return "HexagonISD::VUNPACK"; - case HexagonISD::VUNPACKU: return "HexagonISD::VUNPACKU"; case HexagonISD::ISEL: return "HexagonISD::ISEL"; case HexagonISD::OP_END: break; } @@ -2141,6 +2148,25 @@ bool HexagonTargetLowering::shouldExpandBuildVectorWithShuffles(EVT VT, return false; } +bool HexagonTargetLowering::isExtractSubvectorCheap(EVT ResVT, EVT SrcVT, + unsigned Index) const { + assert(ResVT.getVectorElementType() == SrcVT.getVectorElementType()); + if (!ResVT.isSimple() || !SrcVT.isSimple()) + return false; + + MVT ResTy = ResVT.getSimpleVT(), SrcTy = SrcVT.getSimpleVT(); + if (ResTy.getVectorElementType() != MVT::i1) + return true; + + // Non-HVX bool vectors are relatively cheap. + return SrcTy.getVectorNumElements() <= 8; +} + +bool HexagonTargetLowering::isTargetCanonicalConstantNode(SDValue Op) const { + return Op.getOpcode() == ISD::CONCAT_VECTORS || + TargetLowering::isTargetCanonicalConstantNode(Op); +} + bool HexagonTargetLowering::isShuffleMaskLegal(ArrayRef<int> Mask, EVT VT) const { return true; @@ -2172,6 +2198,16 @@ HexagonTargetLowering::getPreferredVectorAction(MVT VT) const { return TargetLoweringBase::TypeSplitVector; } +TargetLoweringBase::LegalizeAction +HexagonTargetLowering::getCustomOperationAction(SDNode &Op) const { + if (Subtarget.useHVXOps()) { + unsigned Action = getCustomHvxOperationAction(Op); + if (Action != ~0u) + return static_cast<TargetLoweringBase::LegalizeAction>(Action); + } + return TargetLoweringBase::Legal; +} + std::pair<SDValue, int> HexagonTargetLowering::getBaseAndOffset(SDValue Addr) const { if (Addr.getOpcode() == ISD::ADD) { @@ -2259,15 +2295,15 @@ HexagonTargetLowering::LowerVECTOR_SHUFFLE(SDValue Op, SelectionDAG &DAG) } // Byte packs. - SDValue Concat10 = DAG.getNode(HexagonISD::COMBINE, dl, - typeJoin({ty(Op1), ty(Op0)}), {Op1, Op0}); + SDValue Concat10 = + getCombine(Op1, Op0, dl, typeJoin({ty(Op1), ty(Op0)}), DAG); if (MaskIdx == (0x06040200 | MaskUnd)) return getInstr(Hexagon::S2_vtrunehb, dl, VecTy, {Concat10}, DAG); if (MaskIdx == (0x07050301 | MaskUnd)) return getInstr(Hexagon::S2_vtrunohb, dl, VecTy, {Concat10}, DAG); - SDValue Concat01 = DAG.getNode(HexagonISD::COMBINE, dl, - typeJoin({ty(Op0), ty(Op1)}), {Op0, Op1}); + SDValue Concat01 = + getCombine(Op0, Op1, dl, typeJoin({ty(Op0), ty(Op1)}), DAG); if (MaskIdx == (0x02000604 | MaskUnd)) return getInstr(Hexagon::S2_vtrunehb, dl, VecTy, {Concat01}, DAG); if (MaskIdx == (0x03010705 | MaskUnd)) @@ -2309,6 +2345,19 @@ HexagonTargetLowering::LowerVECTOR_SHUFFLE(SDValue Op, SelectionDAG &DAG) return SDValue(); } +SDValue +HexagonTargetLowering::getSplatValue(SDValue Op, SelectionDAG &DAG) const { + switch (Op.getOpcode()) { + case ISD::BUILD_VECTOR: + if (SDValue S = cast<BuildVectorSDNode>(Op)->getSplatValue()) + return S; + break; + case ISD::SPLAT_VECTOR: + return Op.getOperand(0); + } + return SDValue(); +} + // Create a Hexagon-specific node for shifting a vector by an integer. SDValue HexagonTargetLowering::getVectorShiftByInt(SDValue Op, SelectionDAG &DAG) @@ -2328,24 +2377,56 @@ HexagonTargetLowering::getVectorShiftByInt(SDValue Op, SelectionDAG &DAG) llvm_unreachable("Unexpected shift opcode"); } - SDValue Op0 = Op.getOperand(0); - SDValue Op1 = Op.getOperand(1); - const SDLoc &dl(Op); - - switch (Op1.getOpcode()) { - case ISD::BUILD_VECTOR: - if (SDValue S = cast<BuildVectorSDNode>(Op1)->getSplatValue()) - return DAG.getNode(NewOpc, dl, ty(Op), Op0, S); - break; - case ISD::SPLAT_VECTOR: - return DAG.getNode(NewOpc, dl, ty(Op), Op0, Op1.getOperand(0)); - } + if (SDValue Sp = getSplatValue(Op.getOperand(1), DAG)) + return DAG.getNode(NewOpc, SDLoc(Op), ty(Op), Op.getOperand(0), Sp); return SDValue(); } SDValue HexagonTargetLowering::LowerVECTOR_SHIFT(SDValue Op, SelectionDAG &DAG) const { - return getVectorShiftByInt(Op, DAG); + const SDLoc &dl(Op); + + // First try to convert the shift (by vector) to a shift by a scalar. + // If we first split the shift, the shift amount will become 'extract + // subvector', and will no longer be recognized as scalar. + SDValue Res = Op; + if (SDValue S = getVectorShiftByInt(Op, DAG)) + Res = S; + + unsigned Opc = Res.getOpcode(); + switch (Opc) { + case HexagonISD::VASR: + case HexagonISD::VLSR: + case HexagonISD::VASL: + break; + default: + // No instructions for shifts by non-scalars. + return SDValue(); + } + + MVT ResTy = ty(Res); + if (ResTy.getVectorElementType() != MVT::i8) + return Res; + + // For shifts of i8, extend the inputs to i16, then truncate back to i8. + assert(ResTy.getVectorElementType() == MVT::i8); + SDValue Val = Res.getOperand(0), Amt = Res.getOperand(1); + + auto ShiftPartI8 = [&dl, &DAG, this](unsigned Opc, SDValue V, SDValue A) { + MVT Ty = ty(V); + MVT ExtTy = MVT::getVectorVT(MVT::i16, Ty.getVectorNumElements()); + SDValue ExtV = Opc == HexagonISD::VASR ? DAG.getSExtOrTrunc(V, dl, ExtTy) + : DAG.getZExtOrTrunc(V, dl, ExtTy); + SDValue ExtS = DAG.getNode(Opc, dl, ExtTy, {ExtV, A}); + return DAG.getZExtOrTrunc(ExtS, dl, Ty); + }; + + if (ResTy.getSizeInBits() == 32) + return ShiftPartI8(Opc, Val, Amt); + + auto [LoV, HiV] = opSplit(Val, dl, DAG); + return DAG.getNode(ISD::CONCAT_VECTORS, dl, ResTy, + {ShiftPartI8(Opc, LoV, Amt), ShiftPartI8(Opc, HiV, Amt)}); } SDValue @@ -2555,7 +2636,7 @@ HexagonTargetLowering::buildVector64(ArrayRef<SDValue> Elem, const SDLoc &dl, SDValue H = (ElemTy == MVT::i32) ? Elem[1] : buildVector32(Elem.drop_front(Num/2), dl, HalfTy, DAG); - return DAG.getNode(HexagonISD::COMBINE, dl, VecTy, {H, L}); + return getCombine(H, L, dl, VecTy, DAG); } SDValue @@ -2565,60 +2646,13 @@ HexagonTargetLowering::extractVector(SDValue VecV, SDValue IdxV, MVT VecTy = ty(VecV); assert(!ValTy.isVector() || VecTy.getVectorElementType() == ValTy.getVectorElementType()); + if (VecTy.getVectorElementType() == MVT::i1) + return extractVectorPred(VecV, IdxV, dl, ValTy, ResTy, DAG); + unsigned VecWidth = VecTy.getSizeInBits(); unsigned ValWidth = ValTy.getSizeInBits(); unsigned ElemWidth = VecTy.getVectorElementType().getSizeInBits(); assert((VecWidth % ElemWidth) == 0); - auto *IdxN = dyn_cast<ConstantSDNode>(IdxV); - - // Special case for v{8,4,2}i1 (the only boolean vectors legal in Hexagon - // without any coprocessors). - if (ElemWidth == 1) { - assert(VecWidth == VecTy.getVectorNumElements() && - "Vector elements should equal vector width size"); - assert(VecWidth == 8 || VecWidth == 4 || VecWidth == 2); - // Check if this is an extract of the lowest bit. - if (IdxN) { - // Extracting the lowest bit is a no-op, but it changes the type, - // so it must be kept as an operation to avoid errors related to - // type mismatches. - if (IdxN->isZero() && ValTy.getSizeInBits() == 1) - return DAG.getNode(HexagonISD::TYPECAST, dl, MVT::i1, VecV); - } - - // If the value extracted is a single bit, use tstbit. - if (ValWidth == 1) { - SDValue A0 = getInstr(Hexagon::C2_tfrpr, dl, MVT::i32, {VecV}, DAG); - SDValue M0 = DAG.getConstant(8 / VecWidth, dl, MVT::i32); - SDValue I0 = DAG.getNode(ISD::MUL, dl, MVT::i32, IdxV, M0); - return DAG.getNode(HexagonISD::TSTBIT, dl, MVT::i1, A0, I0); - } - - // Each bool vector (v2i1, v4i1, v8i1) always occupies 8 bits in - // a predicate register. The elements of the vector are repeated - // in the register (if necessary) so that the total number is 8. - // The extracted subvector will need to be expanded in such a way. - unsigned Scale = VecWidth / ValWidth; - - // Generate (p2d VecV) >> 8*Idx to move the interesting bytes to - // position 0. - assert(ty(IdxV) == MVT::i32); - unsigned VecRep = 8 / VecWidth; - SDValue S0 = DAG.getNode(ISD::MUL, dl, MVT::i32, IdxV, - DAG.getConstant(8*VecRep, dl, MVT::i32)); - SDValue T0 = DAG.getNode(HexagonISD::P2D, dl, MVT::i64, VecV); - SDValue T1 = DAG.getNode(ISD::SRL, dl, MVT::i64, T0, S0); - while (Scale > 1) { - // The longest possible subvector is at most 32 bits, so it is always - // contained in the low subregister. - T1 = DAG.getTargetExtractSubreg(Hexagon::isub_lo, dl, MVT::i32, T1); - T1 = expandPredicate(T1, dl, DAG); - Scale /= 2; - } - - return DAG.getNode(HexagonISD::D2P, dl, ResTy, T1); - } - assert(VecWidth == 32 || VecWidth == 64); // Cast everything to scalar integer types. @@ -2628,12 +2662,11 @@ HexagonTargetLowering::extractVector(SDValue VecV, SDValue IdxV, SDValue WidthV = DAG.getConstant(ValWidth, dl, MVT::i32); SDValue ExtV; - if (IdxN) { + if (auto *IdxN = dyn_cast<ConstantSDNode>(IdxV)) { unsigned Off = IdxN->getZExtValue() * ElemWidth; if (VecWidth == 64 && ValWidth == 32) { assert(Off == 0 || Off == 32); - unsigned SubIdx = Off == 0 ? Hexagon::isub_lo : Hexagon::isub_hi; - ExtV = DAG.getTargetExtractSubreg(SubIdx, dl, MVT::i32, VecV); + ExtV = Off == 0 ? LoHalf(VecV, DAG) : HiHalf(VecV, DAG); } else if (Off == 0 && (ValWidth % 8) == 0) { ExtV = DAG.getZeroExtendInReg(VecV, dl, tyScalar(ValTy)); } else { @@ -2659,37 +2692,68 @@ HexagonTargetLowering::extractVector(SDValue VecV, SDValue IdxV, } SDValue -HexagonTargetLowering::insertVector(SDValue VecV, SDValue ValV, SDValue IdxV, - const SDLoc &dl, MVT ValTy, - SelectionDAG &DAG) const { +HexagonTargetLowering::extractVectorPred(SDValue VecV, SDValue IdxV, + const SDLoc &dl, MVT ValTy, MVT ResTy, + SelectionDAG &DAG) const { + // Special case for v{8,4,2}i1 (the only boolean vectors legal in Hexagon + // without any coprocessors). MVT VecTy = ty(VecV); - if (VecTy.getVectorElementType() == MVT::i1) { - MVT ValTy = ty(ValV); - assert(ValTy.getVectorElementType() == MVT::i1); - SDValue ValR = DAG.getNode(HexagonISD::P2D, dl, MVT::i64, ValV); - unsigned VecLen = VecTy.getVectorNumElements(); - unsigned Scale = VecLen / ValTy.getVectorNumElements(); - assert(Scale > 1); - - for (unsigned R = Scale; R > 1; R /= 2) { - ValR = contractPredicate(ValR, dl, DAG); - ValR = DAG.getNode(HexagonISD::COMBINE, dl, MVT::i64, - DAG.getUNDEF(MVT::i32), ValR); - } + unsigned VecWidth = VecTy.getSizeInBits(); + unsigned ValWidth = ValTy.getSizeInBits(); + assert(VecWidth == VecTy.getVectorNumElements() && + "Vector elements should equal vector width size"); + assert(VecWidth == 8 || VecWidth == 4 || VecWidth == 2); + + // Check if this is an extract of the lowest bit. + if (auto *IdxN = dyn_cast<ConstantSDNode>(IdxV)) { + // Extracting the lowest bit is a no-op, but it changes the type, + // so it must be kept as an operation to avoid errors related to + // type mismatches. + if (IdxN->isZero() && ValTy.getSizeInBits() == 1) + return DAG.getNode(HexagonISD::TYPECAST, dl, MVT::i1, VecV); + } + + // If the value extracted is a single bit, use tstbit. + if (ValWidth == 1) { + SDValue A0 = getInstr(Hexagon::C2_tfrpr, dl, MVT::i32, {VecV}, DAG); + SDValue M0 = DAG.getConstant(8 / VecWidth, dl, MVT::i32); + SDValue I0 = DAG.getNode(ISD::MUL, dl, MVT::i32, IdxV, M0); + return DAG.getNode(HexagonISD::TSTBIT, dl, MVT::i1, A0, I0); + } + + // Each bool vector (v2i1, v4i1, v8i1) always occupies 8 bits in + // a predicate register. The elements of the vector are repeated + // in the register (if necessary) so that the total number is 8. + // The extracted subvector will need to be expanded in such a way. + unsigned Scale = VecWidth / ValWidth; + + // Generate (p2d VecV) >> 8*Idx to move the interesting bytes to + // position 0. + assert(ty(IdxV) == MVT::i32); + unsigned VecRep = 8 / VecWidth; + SDValue S0 = DAG.getNode(ISD::MUL, dl, MVT::i32, IdxV, + DAG.getConstant(8*VecRep, dl, MVT::i32)); + SDValue T0 = DAG.getNode(HexagonISD::P2D, dl, MVT::i64, VecV); + SDValue T1 = DAG.getNode(ISD::SRL, dl, MVT::i64, T0, S0); + while (Scale > 1) { // The longest possible subvector is at most 32 bits, so it is always // contained in the low subregister. - ValR = DAG.getTargetExtractSubreg(Hexagon::isub_lo, dl, MVT::i32, ValR); - - unsigned ValBytes = 64 / Scale; - SDValue Width = DAG.getConstant(ValBytes*8, dl, MVT::i32); - SDValue Idx = DAG.getNode(ISD::MUL, dl, MVT::i32, IdxV, - DAG.getConstant(8, dl, MVT::i32)); - SDValue VecR = DAG.getNode(HexagonISD::P2D, dl, MVT::i64, VecV); - SDValue Ins = DAG.getNode(HexagonISD::INSERT, dl, MVT::i32, - {VecR, ValR, Width, Idx}); - return DAG.getNode(HexagonISD::D2P, dl, VecTy, Ins); + T1 = LoHalf(T1, DAG); + T1 = expandPredicate(T1, dl, DAG); + Scale /= 2; } + return DAG.getNode(HexagonISD::D2P, dl, ResTy, T1); +} + +SDValue +HexagonTargetLowering::insertVector(SDValue VecV, SDValue ValV, SDValue IdxV, + const SDLoc &dl, MVT ValTy, + SelectionDAG &DAG) const { + MVT VecTy = ty(VecV); + if (VecTy.getVectorElementType() == MVT::i1) + return insertVectorPred(VecV, ValV, IdxV, dl, ValTy, DAG); + unsigned VecWidth = VecTy.getSizeInBits(); unsigned ValWidth = ValTy.getSizeInBits(); assert(VecWidth == 32 || VecWidth == 64); @@ -2725,12 +2789,52 @@ HexagonTargetLowering::insertVector(SDValue VecV, SDValue ValV, SDValue IdxV, } SDValue +HexagonTargetLowering::insertVectorPred(SDValue VecV, SDValue ValV, + SDValue IdxV, const SDLoc &dl, + MVT ValTy, SelectionDAG &DAG) const { + MVT VecTy = ty(VecV); + unsigned VecLen = VecTy.getVectorNumElements(); + + if (ValTy == MVT::i1) { + SDValue ToReg = getInstr(Hexagon::C2_tfrpr, dl, MVT::i32, {VecV}, DAG); + SDValue Ext = DAG.getSExtOrTrunc(ValV, dl, MVT::i32); + SDValue Width = DAG.getConstant(8 / VecLen, dl, MVT::i32); + SDValue Idx = DAG.getNode(ISD::MUL, dl, MVT::i32, IdxV, Width); + SDValue Ins = + DAG.getNode(HexagonISD::INSERT, dl, MVT::i32, {ToReg, Ext, Width, Idx}); + return getInstr(Hexagon::C2_tfrrp, dl, VecTy, {Ins}, DAG); + } + + assert(ValTy.getVectorElementType() == MVT::i1); + SDValue ValR = ValTy.isVector() + ? DAG.getNode(HexagonISD::P2D, dl, MVT::i64, ValV) + : DAG.getSExtOrTrunc(ValV, dl, MVT::i64); + + unsigned Scale = VecLen / ValTy.getVectorNumElements(); + assert(Scale > 1); + + for (unsigned R = Scale; R > 1; R /= 2) { + ValR = contractPredicate(ValR, dl, DAG); + ValR = getCombine(DAG.getUNDEF(MVT::i32), ValR, dl, MVT::i64, DAG); + } + + SDValue Width = DAG.getConstant(64 / Scale, dl, MVT::i32); + SDValue Idx = DAG.getNode(ISD::MUL, dl, MVT::i32, IdxV, Width); + SDValue VecR = DAG.getNode(HexagonISD::P2D, dl, MVT::i64, VecV); + SDValue Ins = + DAG.getNode(HexagonISD::INSERT, dl, MVT::i64, {VecR, ValR, Width, Idx}); + return DAG.getNode(HexagonISD::D2P, dl, VecTy, Ins); +} + +SDValue HexagonTargetLowering::expandPredicate(SDValue Vec32, const SDLoc &dl, SelectionDAG &DAG) const { assert(ty(Vec32).getSizeInBits() == 32); if (isUndef(Vec32)) return DAG.getUNDEF(MVT::i64); - return getInstr(Hexagon::S2_vsxtbh, dl, MVT::i64, {Vec32}, DAG); + SDValue P = DAG.getBitcast(MVT::v4i8, Vec32); + SDValue X = DAG.getNode(ISD::SIGN_EXTEND, dl, MVT::v4i16, P); + return DAG.getBitcast(MVT::i64, X); } SDValue @@ -2739,7 +2843,12 @@ HexagonTargetLowering::contractPredicate(SDValue Vec64, const SDLoc &dl, assert(ty(Vec64).getSizeInBits() == 64); if (isUndef(Vec64)) return DAG.getUNDEF(MVT::i32); - return getInstr(Hexagon::S2_vtrunehb, dl, MVT::i32, {Vec64}, DAG); + // Collect even bytes: + SDValue A = DAG.getBitcast(MVT::v8i8, Vec64); + SDValue S = DAG.getVectorShuffle(MVT::v8i8, dl, A, DAG.getUNDEF(MVT::v8i8), + {0, 2, 4, 6, 1, 3, 5, 7}); + return extractVector(S, DAG.getConstant(0, dl, MVT::i32), dl, MVT::v4i8, + MVT::i32, DAG); } SDValue @@ -2782,6 +2891,28 @@ HexagonTargetLowering::appendUndef(SDValue Val, MVT ResTy, SelectionDAG &DAG) } SDValue +HexagonTargetLowering::getCombine(SDValue Hi, SDValue Lo, const SDLoc &dl, + MVT ResTy, SelectionDAG &DAG) const { + MVT ElemTy = ty(Hi); + assert(ElemTy == ty(Lo)); + + if (!ElemTy.isVector()) { + assert(ElemTy.isScalarInteger()); + MVT PairTy = MVT::getIntegerVT(2 * ElemTy.getSizeInBits()); + SDValue Pair = DAG.getNode(ISD::BUILD_PAIR, dl, PairTy, Lo, Hi); + return DAG.getBitcast(ResTy, Pair); + } + + unsigned Width = ElemTy.getSizeInBits(); + MVT IntTy = MVT::getIntegerVT(Width); + MVT PairTy = MVT::getIntegerVT(2 * Width); + SDValue Pair = + DAG.getNode(ISD::BUILD_PAIR, dl, PairTy, + {DAG.getBitcast(IntTy, Lo), DAG.getBitcast(IntTy, Hi)}); + return DAG.getBitcast(ResTy, Pair); +} + +SDValue HexagonTargetLowering::LowerBUILD_VECTOR(SDValue Op, SelectionDAG &DAG) const { MVT VecTy = ty(Op); unsigned BW = VecTy.getSizeInBits(); @@ -2842,8 +2973,7 @@ HexagonTargetLowering::LowerCONCAT_VECTORS(SDValue Op, const SDLoc &dl(Op); if (VecTy.getSizeInBits() == 64) { assert(Op.getNumOperands() == 2); - return DAG.getNode(HexagonISD::COMBINE, dl, VecTy, Op.getOperand(1), - Op.getOperand(0)); + return getCombine(Op.getOperand(1), Op.getOperand(0), dl, VecTy, DAG); } MVT ElemTy = VecTy.getVectorElementType(); @@ -2866,10 +2996,9 @@ HexagonTargetLowering::LowerCONCAT_VECTORS(SDValue Op, SDValue W = DAG.getNode(HexagonISD::P2D, dl, MVT::i64, P); for (unsigned R = Scale; R > 1; R /= 2) { W = contractPredicate(W, dl, DAG); - W = DAG.getNode(HexagonISD::COMBINE, dl, MVT::i64, - DAG.getUNDEF(MVT::i32), W); + W = getCombine(DAG.getUNDEF(MVT::i32), W, dl, MVT::i64, DAG); } - W = DAG.getTargetExtractSubreg(Hexagon::isub_lo, dl, MVT::i32, W); + W = LoHalf(W, DAG); Words[IdxW].push_back(W); } @@ -2891,8 +3020,7 @@ HexagonTargetLowering::LowerCONCAT_VECTORS(SDValue Op, // At this point there should only be two words left, and Scale should be 2. assert(Scale == 2 && Words[IdxW].size() == 2); - SDValue WW = DAG.getNode(HexagonISD::COMBINE, dl, MVT::i64, - Words[IdxW][1], Words[IdxW][0]); + SDValue WW = getCombine(Words[IdxW][1], Words[IdxW][0], dl, MVT::i64, DAG); return DAG.getNode(HexagonISD::D2P, dl, VecTy, WW); } @@ -2946,15 +3074,16 @@ SDValue HexagonTargetLowering::LowerLoad(SDValue Op, SelectionDAG &DAG) const { MVT Ty = ty(Op); const SDLoc &dl(Op); - // Lower loads of scalar predicate vectors (v2i1, v4i1, v8i1) to loads of i1 - // followed by a TYPECAST. LoadSDNode *LN = cast<LoadSDNode>(Op.getNode()); - bool DoCast = (Ty == MVT::v2i1 || Ty == MVT::v4i1 || Ty == MVT::v8i1); - if (DoCast) { + MVT MemTy = LN->getMemoryVT().getSimpleVT(); + ISD::LoadExtType ET = LN->getExtensionType(); + + bool LoadPred = MemTy == MVT::v2i1 || MemTy == MVT::v4i1 || MemTy == MVT::v8i1; + if (LoadPred) { SDValue NL = DAG.getLoad( - LN->getAddressingMode(), LN->getExtensionType(), MVT::i1, dl, - LN->getChain(), LN->getBasePtr(), LN->getOffset(), LN->getPointerInfo(), - /*MemoryVT*/ MVT::i1, LN->getAlign(), LN->getMemOperand()->getFlags(), + LN->getAddressingMode(), ISD::ZEXTLOAD, MVT::i32, dl, LN->getChain(), + LN->getBasePtr(), LN->getOffset(), LN->getPointerInfo(), + /*MemoryVT*/ MVT::i8, LN->getAlign(), LN->getMemOperand()->getFlags(), LN->getAAInfo(), LN->getRanges()); LN = cast<LoadSDNode>(NL.getNode()); } @@ -2966,10 +3095,15 @@ HexagonTargetLowering::LowerLoad(SDValue Op, SelectionDAG &DAG) const { // Call LowerUnalignedLoad for all loads, it recognizes loads that // don't need extra aligning. SDValue LU = LowerUnalignedLoad(SDValue(LN, 0), DAG); - if (DoCast) { - SDValue TC = DAG.getNode(HexagonISD::TYPECAST, dl, Ty, LU); + if (LoadPred) { + SDValue TP = getInstr(Hexagon::C2_tfrrp, dl, MemTy, {LU}, DAG); + if (ET == ISD::SEXTLOAD) { + TP = DAG.getSExtOrTrunc(TP, dl, Ty); + } else if (ET != ISD::NON_EXTLOAD) { + TP = DAG.getZExtOrTrunc(TP, dl, Ty); + } SDValue Ch = cast<LoadSDNode>(LU.getNode())->getChain(); - return DAG.getMergeValues({TC, Ch}, dl); + return DAG.getMergeValues({TP, Ch}, dl); } return LU; } @@ -2981,11 +3115,11 @@ HexagonTargetLowering::LowerStore(SDValue Op, SelectionDAG &DAG) const { SDValue Val = SN->getValue(); MVT Ty = ty(Val); - bool DoCast = (Ty == MVT::v2i1 || Ty == MVT::v4i1 || Ty == MVT::v8i1); - if (DoCast) { - SDValue TC = DAG.getNode(HexagonISD::TYPECAST, dl, MVT::i1, Val); - SDValue NS = DAG.getStore(SN->getChain(), dl, TC, SN->getBasePtr(), - SN->getMemOperand()); + if (Ty == MVT::v2i1 || Ty == MVT::v4i1 || Ty == MVT::v8i1) { + // Store the exact predicate (all bits). + SDValue TR = getInstr(Hexagon::C2_tfrpr, dl, MVT::i32, {Val}, DAG); + SDValue NS = DAG.getTruncStore(SN->getChain(), dl, TR, SN->getBasePtr(), + MVT::i8, SN->getMemOperand()); if (SN->isIndexed()) { NS = DAG.getIndexedStore(NS, dl, SN->getBasePtr(), SN->getOffset(), SN->getAddressingMode()); @@ -3249,13 +3383,25 @@ HexagonTargetLowering::LowerOperationWrapper(SDNode *N, return; } - // We are only custom-lowering stores to verify the alignment of the - // address if it is a compile-time constant. Since a store can be modified - // during type-legalization (the value being stored may need legalization), - // return empty Results here to indicate that we don't really make any - // changes in the custom lowering. - if (N->getOpcode() != ISD::STORE) - return TargetLowering::LowerOperationWrapper(N, Results, DAG); + SDValue Op(N, 0); + unsigned Opc = N->getOpcode(); + + switch (Opc) { + case HexagonISD::SSAT: + case HexagonISD::USAT: + Results.push_back(opJoin(SplitVectorOp(Op, DAG), SDLoc(Op), DAG)); + break; + case ISD::STORE: + // We are only custom-lowering stores to verify the alignment of the + // address if it is a compile-time constant. Since a store can be + // modified during type-legalization (the value being stored may need + // legalization), return empty Results here to indicate that we don't + // really make any changes in the custom lowering. + return; + default: + TargetLowering::LowerOperationWrapper(N, Results, DAG); + break; + } } void @@ -3289,30 +3435,45 @@ HexagonTargetLowering::ReplaceNodeResults(SDNode *N, } SDValue -HexagonTargetLowering::PerformDAGCombine(SDNode *N, DAGCombinerInfo &DCI) - const { +HexagonTargetLowering::PerformDAGCombine(SDNode *N, + DAGCombinerInfo &DCI) const { if (isHvxOperation(N, DCI.DAG)) { if (SDValue V = PerformHvxDAGCombine(N, DCI)) return V; return SDValue(); } - if (DCI.isBeforeLegalizeOps()) - return SDValue(); - SDValue Op(N, 0); const SDLoc &dl(Op); unsigned Opc = Op.getOpcode(); + if (Opc == ISD::TRUNCATE) { + SDValue Op0 = Op.getOperand(0); + // fold (truncate (build pair x, y)) -> (truncate x) or x + if (Op0.getOpcode() == ISD::BUILD_PAIR) { + EVT TruncTy = Op.getValueType(); + SDValue Elem0 = Op0.getOperand(0); + // if we match the low element of the pair, just return it. + if (Elem0.getValueType() == TruncTy) + return Elem0; + // otherwise, if the low part is still too large, apply the truncate. + if (Elem0.getValueType().bitsGT(TruncTy)) + return DCI.DAG.getNode(ISD::TRUNCATE, dl, TruncTy, Elem0); + } + } + + if (DCI.isBeforeLegalizeOps()) + return SDValue(); + if (Opc == HexagonISD::P2D) { SDValue P = Op.getOperand(0); switch (P.getOpcode()) { - case HexagonISD::PTRUE: - return DCI.DAG.getConstant(-1, dl, ty(Op)); - case HexagonISD::PFALSE: - return getZero(dl, ty(Op), DCI.DAG); - default: - break; + case HexagonISD::PTRUE: + return DCI.DAG.getConstant(-1, dl, ty(Op)); + case HexagonISD::PFALSE: + return getZero(dl, ty(Op), DCI.DAG); + default: + break; } } else if (Opc == ISD::VSELECT) { // This is pretty much duplicated in HexagonISelLoweringHVX... @@ -3327,6 +3488,49 @@ HexagonTargetLowering::PerformDAGCombine(SDNode *N, DAGCombinerInfo &DCI) return VSel; } } + } else if (Opc == ISD::TRUNCATE) { + SDValue Op0 = Op.getOperand(0); + // fold (truncate (build pair x, y)) -> (truncate x) or x + if (Op0.getOpcode() == ISD::BUILD_PAIR) { + MVT TruncTy = ty(Op); + SDValue Elem0 = Op0.getOperand(0); + // if we match the low element of the pair, just return it. + if (ty(Elem0) == TruncTy) + return Elem0; + // otherwise, if the low part is still too large, apply the truncate. + if (ty(Elem0).bitsGT(TruncTy)) + return DCI.DAG.getNode(ISD::TRUNCATE, dl, TruncTy, Elem0); + } + } else if (Opc == ISD::OR) { + // fold (or (shl xx, s), (zext y)) -> (COMBINE (shl xx, s-32), y) + // if s >= 32 + auto fold0 = [&, this](SDValue Op) { + if (ty(Op) != MVT::i64) + return SDValue(); + SDValue Shl = Op.getOperand(0); + SDValue Zxt = Op.getOperand(1); + if (Shl.getOpcode() != ISD::SHL) + std::swap(Shl, Zxt); + + if (Shl.getOpcode() != ISD::SHL || Zxt.getOpcode() != ISD::ZERO_EXTEND) + return SDValue(); + + SDValue Z = Zxt.getOperand(0); + auto *Amt = dyn_cast<ConstantSDNode>(Shl.getOperand(1)); + if (Amt && Amt->getZExtValue() >= 32 && ty(Z).getSizeInBits() <= 32) { + unsigned A = Amt->getZExtValue(); + SDValue S = Shl.getOperand(0); + SDValue T0 = DCI.DAG.getNode(ISD::SHL, dl, ty(S), S, + DCI.DAG.getConstant(32 - A, dl, MVT::i32)); + SDValue T1 = DCI.DAG.getZExtOrTrunc(T0, dl, MVT::i32); + SDValue T2 = DCI.DAG.getZExtOrTrunc(Z, dl, MVT::i32); + return DCI.DAG.getNode(HexagonISD::COMBINE, dl, MVT::i64, {T1, T2}); + } + return SDValue(); + }; + + if (SDValue R = fold0(Op)) + return R; } return SDValue(); @@ -3559,7 +3763,7 @@ EVT HexagonTargetLowering::getOptimalMemOpType( bool HexagonTargetLowering::allowsMemoryAccess( LLVMContext &Context, const DataLayout &DL, EVT VT, unsigned AddrSpace, - Align Alignment, MachineMemOperand::Flags Flags, bool *Fast) const { + Align Alignment, MachineMemOperand::Flags Flags, unsigned *Fast) const { MVT SVT = VT.getSimpleVT(); if (Subtarget.isHVXVectorType(SVT, true)) return allowsHvxMemoryAccess(SVT, Flags, Fast); @@ -3569,12 +3773,12 @@ bool HexagonTargetLowering::allowsMemoryAccess( bool HexagonTargetLowering::allowsMisalignedMemoryAccesses( EVT VT, unsigned AddrSpace, Align Alignment, MachineMemOperand::Flags Flags, - bool *Fast) const { + unsigned *Fast) const { MVT SVT = VT.getSimpleVT(); if (Subtarget.isHVXVectorType(SVT, true)) return allowsHvxMisalignedMemoryAccesses(SVT, Flags, Fast); if (Fast) - *Fast = false; + *Fast = 0; return false; } @@ -3615,6 +3819,11 @@ bool HexagonTargetLowering::shouldReduceLoadWidth(SDNode *Load, return true; } +void HexagonTargetLowering::AdjustInstrPostInstrSelection(MachineInstr &MI, + SDNode *Node) const { + AdjustHvxInstrPostInstrSelection(MI, Node); +} + Value *HexagonTargetLowering::emitLoadLinked(IRBuilderBase &Builder, Type *ValueTy, Value *Addr, AtomicOrdering Ord) const { |