diff options
Diffstat (limited to 'contrib/llvm-project/llvm/lib/Target/RISCV/RISCVISelLowering.cpp')
| -rw-r--r-- | contrib/llvm-project/llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 230 |
1 files changed, 165 insertions, 65 deletions
diff --git a/contrib/llvm-project/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/contrib/llvm-project/llvm/lib/Target/RISCV/RISCVISelLowering.cpp index c2508a158837..03a59f8a8b57 100644 --- a/contrib/llvm-project/llvm/lib/Target/RISCV/RISCVISelLowering.cpp +++ b/contrib/llvm-project/llvm/lib/Target/RISCV/RISCVISelLowering.cpp @@ -1374,8 +1374,8 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM, setPrefLoopAlignment(Subtarget.getPrefLoopAlignment()); setTargetDAGCombine({ISD::INTRINSIC_VOID, ISD::INTRINSIC_W_CHAIN, - ISD::INTRINSIC_WO_CHAIN, ISD::ADD, ISD::SUB, ISD::AND, - ISD::OR, ISD::XOR, ISD::SETCC, ISD::SELECT}); + ISD::INTRINSIC_WO_CHAIN, ISD::ADD, ISD::SUB, ISD::MUL, + ISD::AND, ISD::OR, ISD::XOR, ISD::SETCC, ISD::SELECT}); if (Subtarget.is64Bit()) setTargetDAGCombine(ISD::SRA); @@ -7235,7 +7235,7 @@ SDValue RISCVTargetLowering::lowerFRAMEADDR(SDValue Op, EVT VT = Op.getValueType(); SDLoc DL(Op); SDValue FrameAddr = DAG.getCopyFromReg(DAG.getEntryNode(), DL, FrameReg, VT); - unsigned Depth = cast<ConstantSDNode>(Op.getOperand(0))->getZExtValue(); + unsigned Depth = Op.getConstantOperandVal(0); while (Depth--) { int Offset = -(XLenInBytes * 2); SDValue Ptr = DAG.getNode(ISD::ADD, DL, VT, FrameAddr, @@ -7260,7 +7260,7 @@ SDValue RISCVTargetLowering::lowerRETURNADDR(SDValue Op, EVT VT = Op.getValueType(); SDLoc DL(Op); - unsigned Depth = cast<ConstantSDNode>(Op.getOperand(0))->getZExtValue(); + unsigned Depth = Op.getConstantOperandVal(0); if (Depth) { int Off = -XLenInBytes; SDValue FrameAddr = lowerFRAMEADDR(Op, DAG); @@ -11731,7 +11731,7 @@ void RISCVTargetLowering::ReplaceNodeResults(SDNode *N, break; } case ISD::INTRINSIC_WO_CHAIN: { - unsigned IntNo = cast<ConstantSDNode>(N->getOperand(0))->getZExtValue(); + unsigned IntNo = N->getConstantOperandVal(0); switch (IntNo) { default: llvm_unreachable( @@ -12850,9 +12850,9 @@ struct CombineResult; /// Helper class for folding sign/zero extensions. /// In particular, this class is used for the following combines: -/// add_vl -> vwadd(u) | vwadd(u)_w -/// sub_vl -> vwsub(u) | vwsub(u)_w -/// mul_vl -> vwmul(u) | vwmul_su +/// add | add_vl -> vwadd(u) | vwadd(u)_w +/// sub | sub_vl -> vwsub(u) | vwsub(u)_w +/// mul | mul_vl -> vwmul(u) | vwmul_su /// /// An object of this class represents an operand of the operation we want to /// combine. @@ -12897,6 +12897,8 @@ struct NodeExtensionHelper { /// E.g., for zext(a), this would return a. SDValue getSource() const { switch (OrigOperand.getOpcode()) { + case ISD::ZERO_EXTEND: + case ISD::SIGN_EXTEND: case RISCVISD::VSEXT_VL: case RISCVISD::VZEXT_VL: return OrigOperand.getOperand(0); @@ -12913,7 +12915,8 @@ struct NodeExtensionHelper { /// Get or create a value that can feed \p Root with the given extension \p /// SExt. If \p SExt is std::nullopt, this returns the source of this operand. /// \see ::getSource(). - SDValue getOrCreateExtendedOp(const SDNode *Root, SelectionDAG &DAG, + SDValue getOrCreateExtendedOp(SDNode *Root, SelectionDAG &DAG, + const RISCVSubtarget &Subtarget, std::optional<bool> SExt) const { if (!SExt.has_value()) return OrigOperand; @@ -12928,8 +12931,10 @@ struct NodeExtensionHelper { // If we need an extension, we should be changing the type. SDLoc DL(Root); - auto [Mask, VL] = getMaskAndVL(Root); + auto [Mask, VL] = getMaskAndVL(Root, DAG, Subtarget); switch (OrigOperand.getOpcode()) { + case ISD::ZERO_EXTEND: + case ISD::SIGN_EXTEND: case RISCVISD::VSEXT_VL: case RISCVISD::VZEXT_VL: return DAG.getNode(ExtOpc, DL, NarrowVT, Source, Mask, VL); @@ -12969,12 +12974,15 @@ struct NodeExtensionHelper { /// \pre \p Opcode represents a supported root (\see ::isSupportedRoot()). static unsigned getSameExtensionOpcode(unsigned Opcode, bool IsSExt) { switch (Opcode) { + case ISD::ADD: case RISCVISD::ADD_VL: case RISCVISD::VWADD_W_VL: case RISCVISD::VWADDU_W_VL: return IsSExt ? RISCVISD::VWADD_VL : RISCVISD::VWADDU_VL; + case ISD::MUL: case RISCVISD::MUL_VL: return IsSExt ? RISCVISD::VWMUL_VL : RISCVISD::VWMULU_VL; + case ISD::SUB: case RISCVISD::SUB_VL: case RISCVISD::VWSUB_W_VL: case RISCVISD::VWSUBU_W_VL: @@ -12987,7 +12995,8 @@ struct NodeExtensionHelper { /// Get the opcode to materialize \p Opcode(sext(a), zext(b)) -> /// newOpcode(a, b). static unsigned getSUOpcode(unsigned Opcode) { - assert(Opcode == RISCVISD::MUL_VL && "SU is only supported for MUL"); + assert((Opcode == RISCVISD::MUL_VL || Opcode == ISD::MUL) && + "SU is only supported for MUL"); return RISCVISD::VWMULSU_VL; } @@ -12995,8 +13004,10 @@ struct NodeExtensionHelper { /// newOpcode(a, b). static unsigned getWOpcode(unsigned Opcode, bool IsSExt) { switch (Opcode) { + case ISD::ADD: case RISCVISD::ADD_VL: return IsSExt ? RISCVISD::VWADD_W_VL : RISCVISD::VWADDU_W_VL; + case ISD::SUB: case RISCVISD::SUB_VL: return IsSExt ? RISCVISD::VWSUB_W_VL : RISCVISD::VWSUBU_W_VL; default: @@ -13006,19 +13017,33 @@ struct NodeExtensionHelper { using CombineToTry = std::function<std::optional<CombineResult>( SDNode * /*Root*/, const NodeExtensionHelper & /*LHS*/, - const NodeExtensionHelper & /*RHS*/)>; + const NodeExtensionHelper & /*RHS*/, SelectionDAG &, + const RISCVSubtarget &)>; /// Check if this node needs to be fully folded or extended for all users. bool needToPromoteOtherUsers() const { return EnforceOneUse; } /// Helper method to set the various fields of this struct based on the /// type of \p Root. - void fillUpExtensionSupport(SDNode *Root, SelectionDAG &DAG) { + void fillUpExtensionSupport(SDNode *Root, SelectionDAG &DAG, + const RISCVSubtarget &Subtarget) { SupportsZExt = false; SupportsSExt = false; EnforceOneUse = true; CheckMask = true; - switch (OrigOperand.getOpcode()) { + unsigned Opc = OrigOperand.getOpcode(); + switch (Opc) { + case ISD::ZERO_EXTEND: + case ISD::SIGN_EXTEND: { + if (OrigOperand.getValueType().isVector()) { + SupportsZExt = Opc == ISD::ZERO_EXTEND; + SupportsSExt = Opc == ISD::SIGN_EXTEND; + SDLoc DL(Root); + MVT VT = Root->getSimpleValueType(0); + std::tie(Mask, VL) = getDefaultScalableVLOps(VT, DL, DAG, Subtarget); + } + break; + } case RISCVISD::VZEXT_VL: SupportsZExt = true; Mask = OrigOperand.getOperand(1); @@ -13074,8 +13099,16 @@ struct NodeExtensionHelper { } /// Check if \p Root supports any extension folding combines. - static bool isSupportedRoot(const SDNode *Root) { + static bool isSupportedRoot(const SDNode *Root, const SelectionDAG &DAG) { switch (Root->getOpcode()) { + case ISD::ADD: + case ISD::SUB: + case ISD::MUL: { + const TargetLowering &TLI = DAG.getTargetLoweringInfo(); + if (!TLI.isTypeLegal(Root->getValueType(0))) + return false; + return Root->getValueType(0).isScalableVector(); + } case RISCVISD::ADD_VL: case RISCVISD::MUL_VL: case RISCVISD::VWADD_W_VL: @@ -13090,9 +13123,10 @@ struct NodeExtensionHelper { } /// Build a NodeExtensionHelper for \p Root.getOperand(\p OperandIdx). - NodeExtensionHelper(SDNode *Root, unsigned OperandIdx, SelectionDAG &DAG) { - assert(isSupportedRoot(Root) && "Trying to build an helper with an " - "unsupported root"); + NodeExtensionHelper(SDNode *Root, unsigned OperandIdx, SelectionDAG &DAG, + const RISCVSubtarget &Subtarget) { + assert(isSupportedRoot(Root, DAG) && "Trying to build an helper with an " + "unsupported root"); assert(OperandIdx < 2 && "Requesting something else than LHS or RHS"); OrigOperand = Root->getOperand(OperandIdx); @@ -13108,7 +13142,7 @@ struct NodeExtensionHelper { SupportsZExt = Opc == RISCVISD::VWADDU_W_VL || Opc == RISCVISD::VWSUBU_W_VL; SupportsSExt = !SupportsZExt; - std::tie(Mask, VL) = getMaskAndVL(Root); + std::tie(Mask, VL) = getMaskAndVL(Root, DAG, Subtarget); CheckMask = true; // There's no existing extension here, so we don't have to worry about // making sure it gets removed. @@ -13117,7 +13151,7 @@ struct NodeExtensionHelper { } [[fallthrough]]; default: - fillUpExtensionSupport(Root, DAG); + fillUpExtensionSupport(Root, DAG, Subtarget); break; } } @@ -13133,14 +13167,27 @@ struct NodeExtensionHelper { } /// Helper function to get the Mask and VL from \p Root. - static std::pair<SDValue, SDValue> getMaskAndVL(const SDNode *Root) { - assert(isSupportedRoot(Root) && "Unexpected root"); - return std::make_pair(Root->getOperand(3), Root->getOperand(4)); + static std::pair<SDValue, SDValue> + getMaskAndVL(const SDNode *Root, SelectionDAG &DAG, + const RISCVSubtarget &Subtarget) { + assert(isSupportedRoot(Root, DAG) && "Unexpected root"); + switch (Root->getOpcode()) { + case ISD::ADD: + case ISD::SUB: + case ISD::MUL: { + SDLoc DL(Root); + MVT VT = Root->getSimpleValueType(0); + return getDefaultScalableVLOps(VT, DL, DAG, Subtarget); + } + default: + return std::make_pair(Root->getOperand(3), Root->getOperand(4)); + } } /// Check if the Mask and VL of this operand are compatible with \p Root. - bool areVLAndMaskCompatible(const SDNode *Root) const { - auto [Mask, VL] = getMaskAndVL(Root); + bool areVLAndMaskCompatible(SDNode *Root, SelectionDAG &DAG, + const RISCVSubtarget &Subtarget) const { + auto [Mask, VL] = getMaskAndVL(Root, DAG, Subtarget); return isMaskCompatible(Mask) && isVLCompatible(VL); } @@ -13148,11 +13195,14 @@ struct NodeExtensionHelper { /// foldings that are supported by this class. static bool isCommutative(const SDNode *N) { switch (N->getOpcode()) { + case ISD::ADD: + case ISD::MUL: case RISCVISD::ADD_VL: case RISCVISD::MUL_VL: case RISCVISD::VWADD_W_VL: case RISCVISD::VWADDU_W_VL: return true; + case ISD::SUB: case RISCVISD::SUB_VL: case RISCVISD::VWSUB_W_VL: case RISCVISD::VWSUBU_W_VL: @@ -13197,14 +13247,25 @@ struct CombineResult { /// Return a value that uses TargetOpcode and that can be used to replace /// Root. /// The actual replacement is *not* done in that method. - SDValue materialize(SelectionDAG &DAG) const { + SDValue materialize(SelectionDAG &DAG, + const RISCVSubtarget &Subtarget) const { SDValue Mask, VL, Merge; - std::tie(Mask, VL) = NodeExtensionHelper::getMaskAndVL(Root); - Merge = Root->getOperand(2); + std::tie(Mask, VL) = + NodeExtensionHelper::getMaskAndVL(Root, DAG, Subtarget); + switch (Root->getOpcode()) { + default: + Merge = Root->getOperand(2); + break; + case ISD::ADD: + case ISD::SUB: + case ISD::MUL: + Merge = DAG.getUNDEF(Root->getValueType(0)); + break; + } return DAG.getNode(TargetOpcode, SDLoc(Root), Root->getValueType(0), - LHS.getOrCreateExtendedOp(Root, DAG, SExtLHS), - RHS.getOrCreateExtendedOp(Root, DAG, SExtRHS), Merge, - Mask, VL); + LHS.getOrCreateExtendedOp(Root, DAG, Subtarget, SExtLHS), + RHS.getOrCreateExtendedOp(Root, DAG, Subtarget, SExtRHS), + Merge, Mask, VL); } }; @@ -13221,15 +13282,16 @@ struct CombineResult { static std::optional<CombineResult> canFoldToVWWithSameExtensionImpl(SDNode *Root, const NodeExtensionHelper &LHS, const NodeExtensionHelper &RHS, bool AllowSExt, - bool AllowZExt) { + bool AllowZExt, SelectionDAG &DAG, + const RISCVSubtarget &Subtarget) { assert((AllowSExt || AllowZExt) && "Forgot to set what you want?"); - if (!LHS.areVLAndMaskCompatible(Root) || !RHS.areVLAndMaskCompatible(Root)) + if (!LHS.areVLAndMaskCompatible(Root, DAG, Subtarget) || + !RHS.areVLAndMaskCompatible(Root, DAG, Subtarget)) return std::nullopt; if (AllowZExt && LHS.SupportsZExt && RHS.SupportsZExt) return CombineResult(NodeExtensionHelper::getSameExtensionOpcode( Root->getOpcode(), /*IsSExt=*/false), - Root, LHS, /*SExtLHS=*/false, RHS, - /*SExtRHS=*/false); + Root, LHS, /*SExtLHS=*/false, RHS, /*SExtRHS=*/false); if (AllowSExt && LHS.SupportsSExt && RHS.SupportsSExt) return CombineResult(NodeExtensionHelper::getSameExtensionOpcode( Root->getOpcode(), /*IsSExt=*/true), @@ -13246,9 +13308,10 @@ canFoldToVWWithSameExtensionImpl(SDNode *Root, const NodeExtensionHelper &LHS, /// can be used to apply the pattern. static std::optional<CombineResult> canFoldToVWWithSameExtension(SDNode *Root, const NodeExtensionHelper &LHS, - const NodeExtensionHelper &RHS) { + const NodeExtensionHelper &RHS, SelectionDAG &DAG, + const RISCVSubtarget &Subtarget) { return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, /*AllowSExt=*/true, - /*AllowZExt=*/true); + /*AllowZExt=*/true, DAG, Subtarget); } /// Check if \p Root follows a pattern Root(LHS, ext(RHS)) @@ -13257,8 +13320,9 @@ canFoldToVWWithSameExtension(SDNode *Root, const NodeExtensionHelper &LHS, /// can be used to apply the pattern. static std::optional<CombineResult> canFoldToVW_W(SDNode *Root, const NodeExtensionHelper &LHS, - const NodeExtensionHelper &RHS) { - if (!RHS.areVLAndMaskCompatible(Root)) + const NodeExtensionHelper &RHS, SelectionDAG &DAG, + const RISCVSubtarget &Subtarget) { + if (!RHS.areVLAndMaskCompatible(Root, DAG, Subtarget)) return std::nullopt; // FIXME: Is it useful to form a vwadd.wx or vwsub.wx if it removes a scalar @@ -13282,9 +13346,10 @@ canFoldToVW_W(SDNode *Root, const NodeExtensionHelper &LHS, /// can be used to apply the pattern. static std::optional<CombineResult> canFoldToVWWithSEXT(SDNode *Root, const NodeExtensionHelper &LHS, - const NodeExtensionHelper &RHS) { + const NodeExtensionHelper &RHS, SelectionDAG &DAG, + const RISCVSubtarget &Subtarget) { return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, /*AllowSExt=*/true, - /*AllowZExt=*/false); + /*AllowZExt=*/false, DAG, Subtarget); } /// Check if \p Root follows a pattern Root(zext(LHS), zext(RHS)) @@ -13293,9 +13358,10 @@ canFoldToVWWithSEXT(SDNode *Root, const NodeExtensionHelper &LHS, /// can be used to apply the pattern. static std::optional<CombineResult> canFoldToVWWithZEXT(SDNode *Root, const NodeExtensionHelper &LHS, - const NodeExtensionHelper &RHS) { + const NodeExtensionHelper &RHS, SelectionDAG &DAG, + const RISCVSubtarget &Subtarget) { return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, /*AllowSExt=*/false, - /*AllowZExt=*/true); + /*AllowZExt=*/true, DAG, Subtarget); } /// Check if \p Root follows a pattern Root(sext(LHS), zext(RHS)) @@ -13304,10 +13370,13 @@ canFoldToVWWithZEXT(SDNode *Root, const NodeExtensionHelper &LHS, /// can be used to apply the pattern. static std::optional<CombineResult> canFoldToVW_SU(SDNode *Root, const NodeExtensionHelper &LHS, - const NodeExtensionHelper &RHS) { + const NodeExtensionHelper &RHS, SelectionDAG &DAG, + const RISCVSubtarget &Subtarget) { + if (!LHS.SupportsSExt || !RHS.SupportsZExt) return std::nullopt; - if (!LHS.areVLAndMaskCompatible(Root) || !RHS.areVLAndMaskCompatible(Root)) + if (!LHS.areVLAndMaskCompatible(Root, DAG, Subtarget) || + !RHS.areVLAndMaskCompatible(Root, DAG, Subtarget)) return std::nullopt; return CombineResult(NodeExtensionHelper::getSUOpcode(Root->getOpcode()), Root, LHS, /*SExtLHS=*/true, RHS, /*SExtRHS=*/false); @@ -13317,6 +13386,8 @@ SmallVector<NodeExtensionHelper::CombineToTry> NodeExtensionHelper::getSupportedFoldings(const SDNode *Root) { SmallVector<CombineToTry> Strategies; switch (Root->getOpcode()) { + case ISD::ADD: + case ISD::SUB: case RISCVISD::ADD_VL: case RISCVISD::SUB_VL: // add|sub -> vwadd(u)|vwsub(u) @@ -13324,6 +13395,7 @@ NodeExtensionHelper::getSupportedFoldings(const SDNode *Root) { // add|sub -> vwadd(u)_w|vwsub(u)_w Strategies.push_back(canFoldToVW_W); break; + case ISD::MUL: case RISCVISD::MUL_VL: // mul -> vwmul(u) Strategies.push_back(canFoldToVWWithSameExtension); @@ -13354,12 +13426,14 @@ NodeExtensionHelper::getSupportedFoldings(const SDNode *Root) { /// mul_vl -> vwmul(u) | vwmul_su /// vwadd_w(u) -> vwadd(u) /// vwub_w(u) -> vwadd(u) -static SDValue -combineBinOp_VLToVWBinOp_VL(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) { +static SDValue combineBinOp_VLToVWBinOp_VL(SDNode *N, + TargetLowering::DAGCombinerInfo &DCI, + const RISCVSubtarget &Subtarget) { SelectionDAG &DAG = DCI.DAG; - assert(NodeExtensionHelper::isSupportedRoot(N) && - "Shouldn't have called this method"); + if (!NodeExtensionHelper::isSupportedRoot(N, DAG)) + return SDValue(); + SmallVector<SDNode *> Worklist; SmallSet<SDNode *, 8> Inserted; Worklist.push_back(N); @@ -13368,11 +13442,11 @@ combineBinOp_VLToVWBinOp_VL(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) { while (!Worklist.empty()) { SDNode *Root = Worklist.pop_back_val(); - if (!NodeExtensionHelper::isSupportedRoot(Root)) + if (!NodeExtensionHelper::isSupportedRoot(Root, DAG)) return SDValue(); - NodeExtensionHelper LHS(N, 0, DAG); - NodeExtensionHelper RHS(N, 1, DAG); + NodeExtensionHelper LHS(N, 0, DAG, Subtarget); + NodeExtensionHelper RHS(N, 1, DAG, Subtarget); auto AppendUsersIfNeeded = [&Worklist, &Inserted](const NodeExtensionHelper &Op) { if (Op.needToPromoteOtherUsers()) { @@ -13399,7 +13473,8 @@ combineBinOp_VLToVWBinOp_VL(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) { for (NodeExtensionHelper::CombineToTry FoldingStrategy : FoldingStrategies) { - std::optional<CombineResult> Res = FoldingStrategy(N, LHS, RHS); + std::optional<CombineResult> Res = + FoldingStrategy(N, LHS, RHS, DAG, Subtarget); if (Res) { Matched = true; CombinesToApply.push_back(*Res); @@ -13428,7 +13503,7 @@ combineBinOp_VLToVWBinOp_VL(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) { SmallVector<std::pair<SDValue, SDValue>> ValuesToReplace; ValuesToReplace.reserve(CombinesToApply.size()); for (CombineResult Res : CombinesToApply) { - SDValue NewValue = Res.materialize(DAG); + SDValue NewValue = Res.materialize(DAG, Subtarget); if (!InputRootReplacement) { assert(Res.Root == N && "First element is expected to be the current node"); @@ -14078,7 +14153,7 @@ static SDValue performSRACombine(SDNode *N, SelectionDAG &DAG, for (SDNode *U : N0->uses()) { if (U->getOpcode() != ISD::SRA || !isa<ConstantSDNode>(U->getOperand(1)) || - cast<ConstantSDNode>(U->getOperand(1))->getZExtValue() > 32) + U->getConstantOperandVal(1) > 32) return SDValue(); } @@ -14700,13 +14775,20 @@ static SDValue performCONCAT_VECTORSCombine(SDNode *N, SelectionDAG &DAG, static SDValue combineToVWMACC(SDNode *N, SelectionDAG &DAG, const RISCVSubtarget &Subtarget) { - assert(N->getOpcode() == RISCVISD::ADD_VL); + + assert(N->getOpcode() == RISCVISD::ADD_VL || N->getOpcode() == ISD::ADD); + + if (N->getValueType(0).isFixedLengthVector()) + return SDValue(); + SDValue Addend = N->getOperand(0); SDValue MulOp = N->getOperand(1); - SDValue AddMergeOp = N->getOperand(2); - if (!AddMergeOp.isUndef()) - return SDValue(); + if (N->getOpcode() == RISCVISD::ADD_VL) { + SDValue AddMergeOp = N->getOperand(2); + if (!AddMergeOp.isUndef()) + return SDValue(); + } auto IsVWMulOpc = [](unsigned Opc) { switch (Opc) { @@ -14730,8 +14812,16 @@ static SDValue combineToVWMACC(SDNode *N, SelectionDAG &DAG, if (!MulMergeOp.isUndef()) return SDValue(); - SDValue AddMask = N->getOperand(3); - SDValue AddVL = N->getOperand(4); + auto [AddMask, AddVL] = [](SDNode *N, SelectionDAG &DAG, + const RISCVSubtarget &Subtarget) { + if (N->getOpcode() == ISD::ADD) { + SDLoc DL(N); + return getDefaultScalableVLOps(N->getSimpleValueType(0), DL, DAG, + Subtarget); + } + return std::make_pair(N->getOperand(3), N->getOperand(4)); + }(N, DAG, Subtarget); + SDValue MulMask = MulOp.getOperand(3); SDValue MulVL = MulOp.getOperand(4); @@ -14997,10 +15087,18 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N, return DAG.getNode(ISD::AND, DL, VT, NewFMV, DAG.getConstant(~SignBit, DL, VT)); } - case ISD::ADD: + case ISD::ADD: { + if (SDValue V = combineBinOp_VLToVWBinOp_VL(N, DCI, Subtarget)) + return V; + if (SDValue V = combineToVWMACC(N, DAG, Subtarget)) + return V; return performADDCombine(N, DAG, Subtarget); - case ISD::SUB: + } + case ISD::SUB: { + if (SDValue V = combineBinOp_VLToVWBinOp_VL(N, DCI, Subtarget)) + return V; return performSUBCombine(N, DAG, Subtarget); + } case ISD::AND: return performANDCombine(N, DCI, Subtarget); case ISD::OR: @@ -15008,6 +15106,8 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N, case ISD::XOR: return performXORCombine(N, DAG, Subtarget); case ISD::MUL: + if (SDValue V = combineBinOp_VLToVWBinOp_VL(N, DCI, Subtarget)) + return V; return performMULCombine(N, DAG); case ISD::FADD: case ISD::UMAX: @@ -15484,7 +15584,7 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N, break; } case RISCVISD::ADD_VL: - if (SDValue V = combineBinOp_VLToVWBinOp_VL(N, DCI)) + if (SDValue V = combineBinOp_VLToVWBinOp_VL(N, DCI, Subtarget)) return V; return combineToVWMACC(N, DAG, Subtarget); case RISCVISD::SUB_VL: @@ -15493,7 +15593,7 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N, case RISCVISD::VWSUB_W_VL: case RISCVISD::VWSUBU_W_VL: case RISCVISD::MUL_VL: - return combineBinOp_VLToVWBinOp_VL(N, DCI); + return combineBinOp_VLToVWBinOp_VL(N, DCI, Subtarget); case RISCVISD::VFMADD_VL: case RISCVISD::VFNMADD_VL: case RISCVISD::VFMSUB_VL: |
