aboutsummaryrefslogtreecommitdiff
path: root/contrib/llvm-project/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'contrib/llvm-project/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp')
-rw-r--r--contrib/llvm-project/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp131
1 files changed, 123 insertions, 8 deletions
diff --git a/contrib/llvm-project/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/contrib/llvm-project/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index ce400ea43f29..df5a041b87cd 100644
--- a/contrib/llvm-project/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/contrib/llvm-project/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -4436,7 +4436,7 @@ SDValue DAGCombiner::visitREM(SDNode *N) {
if (DAG.isKnownNeverZero(N1) && !TLI.isIntDivCheap(VT, Attr)) {
SDValue OptimizedDiv =
isSigned ? visitSDIVLike(N0, N1, N) : visitUDIVLike(N0, N1, N);
- if (OptimizedDiv.getNode()) {
+ if (OptimizedDiv.getNode() && OptimizedDiv.getNode() != N) {
// If the equivalent Div node also exists, update its users.
unsigned DivOpcode = isSigned ? ISD::SDIV : ISD::UDIV;
if (SDNode *DivNode = DAG.getNodeIfExists(DivOpcode, N->getVTList(),
@@ -4464,6 +4464,9 @@ SDValue DAGCombiner::visitMULHS(SDNode *N) {
SDLoc DL(N);
if (VT.isVector()) {
+ if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
+ return FoldedVOp;
+
// fold (mulhs x, 0) -> 0
// do not return N0/N1, because undef node may exist.
if (ISD::isConstantSplatVectorAllZeros(N0.getNode()) ||
@@ -4521,6 +4524,9 @@ SDValue DAGCombiner::visitMULHU(SDNode *N) {
SDLoc DL(N);
if (VT.isVector()) {
+ if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
+ return FoldedVOp;
+
// fold (mulhu x, 0) -> 0
// do not return N0/N1, because undef node may exist.
if (ISD::isConstantSplatVectorAllZeros(N0.getNode()) ||
@@ -4779,6 +4785,106 @@ SDValue DAGCombiner::visitMULO(SDNode *N) {
return SDValue();
}
+// Function to calculate whether the Min/Max pair of SDNodes (potentially
+// swapped around) make a signed saturate pattern, clamping to between -2^(BW-1)
+// and 2^(BW-1)-1. Returns the node being clamped and the bitwidth of the clamp
+// in BW. Should work with both SMIN/SMAX nodes and setcc/select combo. The
+// operands are the same as SimplifySelectCC. N0<N1 ? N2 : N3
+static SDValue isSaturatingMinMax(SDValue N0, SDValue N1, SDValue N2,
+ SDValue N3, ISD::CondCode CC, unsigned &BW) {
+ auto isSignedMinMax = [&](SDValue N0, SDValue N1, SDValue N2, SDValue N3,
+ ISD::CondCode CC) {
+ // The compare and select operand should be the same or the select operands
+ // should be truncated versions of the comparison.
+ if (N0 != N2 && (N2.getOpcode() != ISD::TRUNCATE || N0 != N2.getOperand(0)))
+ return 0;
+ // The constants need to be the same or a truncated version of each other.
+ ConstantSDNode *N1C = isConstOrConstSplat(N1);
+ ConstantSDNode *N3C = isConstOrConstSplat(N3);
+ if (!N1C || !N3C)
+ return 0;
+ const APInt &C1 = N1C->getAPIntValue();
+ const APInt &C2 = N3C->getAPIntValue();
+ if (C1.getBitWidth() < C2.getBitWidth() ||
+ C1 != C2.sextOrSelf(C1.getBitWidth()))
+ return 0;
+ return CC == ISD::SETLT ? ISD::SMIN : (CC == ISD::SETGT ? ISD::SMAX : 0);
+ };
+
+ // Check the initial value is a SMIN/SMAX equivalent.
+ unsigned Opcode0 = isSignedMinMax(N0, N1, N2, N3, CC);
+ if (!Opcode0)
+ return SDValue();
+
+ SDValue N00, N01, N02, N03;
+ ISD::CondCode N0CC;
+ switch (N0.getOpcode()) {
+ case ISD::SMIN:
+ case ISD::SMAX:
+ N00 = N02 = N0.getOperand(0);
+ N01 = N03 = N0.getOperand(1);
+ N0CC = N0.getOpcode() == ISD::SMIN ? ISD::SETLT : ISD::SETGT;
+ break;
+ case ISD::SELECT_CC:
+ N00 = N0.getOperand(0);
+ N01 = N0.getOperand(1);
+ N02 = N0.getOperand(2);
+ N03 = N0.getOperand(3);
+ N0CC = cast<CondCodeSDNode>(N0.getOperand(4))->get();
+ break;
+ case ISD::SELECT:
+ case ISD::VSELECT:
+ if (N0.getOperand(0).getOpcode() != ISD::SETCC)
+ return SDValue();
+ N00 = N0.getOperand(0).getOperand(0);
+ N01 = N0.getOperand(0).getOperand(1);
+ N02 = N0.getOperand(1);
+ N03 = N0.getOperand(2);
+ N0CC = cast<CondCodeSDNode>(N0.getOperand(0).getOperand(2))->get();
+ break;
+ default:
+ return SDValue();
+ }
+
+ unsigned Opcode1 = isSignedMinMax(N00, N01, N02, N03, N0CC);
+ if (!Opcode1 || Opcode0 == Opcode1)
+ return SDValue();
+
+ ConstantSDNode *MinCOp = isConstOrConstSplat(Opcode0 == ISD::SMIN ? N1 : N01);
+ ConstantSDNode *MaxCOp = isConstOrConstSplat(Opcode0 == ISD::SMIN ? N01 : N1);
+ if (!MinCOp || !MaxCOp || MinCOp->getValueType(0) != MaxCOp->getValueType(0))
+ return SDValue();
+
+ const APInt &MinC = MinCOp->getAPIntValue();
+ const APInt &MaxC = MaxCOp->getAPIntValue();
+ APInt MinCPlus1 = MinC + 1;
+ if (-MaxC != MinCPlus1 || !MinCPlus1.isPowerOf2())
+ return SDValue();
+ BW = MinCPlus1.exactLogBase2() + 1;
+ return N02;
+}
+
+static SDValue PerformMinMaxFpToSatCombine(SDValue N0, SDValue N1, SDValue N2,
+ SDValue N3, ISD::CondCode CC,
+ SelectionDAG &DAG) {
+ unsigned BW;
+ SDValue Fp = isSaturatingMinMax(N0, N1, N2, N3, CC, BW);
+ if (!Fp || Fp.getOpcode() != ISD::FP_TO_SINT)
+ return SDValue();
+ EVT FPVT = Fp.getOperand(0).getValueType();
+ EVT NewVT = EVT::getIntegerVT(*DAG.getContext(), BW);
+ if (FPVT.isVector())
+ NewVT = EVT::getVectorVT(*DAG.getContext(), NewVT,
+ FPVT.getVectorElementCount());
+ if (!DAG.getTargetLoweringInfo().shouldConvertFpToSat(
+ ISD::FP_TO_SINT_SAT, Fp.getOperand(0).getValueType(), NewVT))
+ return SDValue();
+ SDLoc DL(Fp);
+ SDValue Sat = DAG.getNode(ISD::FP_TO_SINT_SAT, DL, NewVT, Fp.getOperand(0),
+ DAG.getValueType(NewVT.getScalarType()));
+ return DAG.getSExtOrTrunc(Sat, DL, N2->getValueType(0));
+}
+
SDValue DAGCombiner::visitIMINMAX(SDNode *N) {
SDValue N0 = N->getOperand(0);
SDValue N1 = N->getOperand(1);
@@ -4817,6 +4923,11 @@ SDValue DAGCombiner::visitIMINMAX(SDNode *N) {
return DAG.getNode(AltOpcode, DL, VT, N0, N1);
}
+ if (Opcode == ISD::SMIN || Opcode == ISD::SMAX)
+ if (SDValue S = PerformMinMaxFpToSatCombine(
+ N0, N1, N0, N1, Opcode == ISD::SMIN ? ISD::SETLT : ISD::SETGT, DAG))
+ return S;
+
// Simplify the operands using demanded-bits information.
if (SimplifyDemandedBits(SDValue(N, 0)))
return SDValue(N, 0);
@@ -9940,9 +10051,8 @@ SDValue DAGCombiner::visitMSTORE(SDNode *N) {
// If this is a masked load with an all ones mask, we can use a unmasked load.
// FIXME: Can we do this for indexed, compressing, or truncating stores?
- if (ISD::isConstantSplatVectorAllOnes(Mask.getNode()) &&
- MST->isUnindexed() && !MST->isCompressingStore() &&
- !MST->isTruncatingStore())
+ if (ISD::isConstantSplatVectorAllOnes(Mask.getNode()) && MST->isUnindexed() &&
+ !MST->isCompressingStore() && !MST->isTruncatingStore())
return DAG.getStore(MST->getChain(), SDLoc(N), MST->getValue(),
MST->getBasePtr(), MST->getMemOperand());
@@ -9997,9 +10107,8 @@ SDValue DAGCombiner::visitMLOAD(SDNode *N) {
// If this is a masked load with an all ones mask, we can use a unmasked load.
// FIXME: Can we do this for indexed, expanding, or extending loads?
- if (ISD::isConstantSplatVectorAllOnes(Mask.getNode()) &&
- MLD->isUnindexed() && !MLD->isExpandingLoad() &&
- MLD->getExtensionType() == ISD::NON_EXTLOAD) {
+ if (ISD::isConstantSplatVectorAllOnes(Mask.getNode()) && MLD->isUnindexed() &&
+ !MLD->isExpandingLoad() && MLD->getExtensionType() == ISD::NON_EXTLOAD) {
SDValue NewLd = DAG.getLoad(N->getValueType(0), SDLoc(N), MLD->getChain(),
MLD->getBasePtr(), MLD->getMemOperand());
return CombineTo(N, NewLd, NewLd.getValue(1));
@@ -10138,6 +10247,9 @@ SDValue DAGCombiner::visitVSELECT(SDNode *N) {
return FMinMax;
}
+ if (SDValue S = PerformMinMaxFpToSatCombine(LHS, RHS, N1, N2, CC, DAG))
+ return S;
+
// If this select has a condition (setcc) with narrower operands than the
// select, try to widen the compare to match the select width.
// TODO: This should be extended to handle any constant.
@@ -15007,7 +15119,7 @@ SDValue DAGCombiner::visitFP_EXTEND(SDNode *N) {
// fold (fpext (load x)) -> (fpext (fptrunc (extload x)))
if (ISD::isNormalLoad(N0.getNode()) && N0.hasOneUse() &&
- TLI.isLoadExtLegal(ISD::EXTLOAD, VT, N0.getValueType())) {
+ TLI.isLoadExtLegalOrCustom(ISD::EXTLOAD, VT, N0.getValueType())) {
LoadSDNode *LN0 = cast<LoadSDNode>(N0);
SDValue ExtLoad = DAG.getExtLoad(ISD::EXTLOAD, SDLoc(N), VT,
LN0->getChain(),
@@ -23034,6 +23146,9 @@ SDValue DAGCombiner::SimplifySelectCC(const SDLoc &DL, SDValue N0, SDValue N1,
DAG.getSExtOrTrunc(CC == ISD::SETLT ? N3 : N2, DL, VT));
}
+ if (SDValue S = PerformMinMaxFpToSatCombine(N0, N1, N2, N3, CC, DAG))
+ return S;
+
return SDValue();
}