diff options
Diffstat (limited to 'contrib/llvm-project/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp')
| -rw-r--r-- | contrib/llvm-project/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp | 131 |
1 files changed, 123 insertions, 8 deletions
diff --git a/contrib/llvm-project/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/contrib/llvm-project/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp index ce400ea43f29..df5a041b87cd 100644 --- a/contrib/llvm-project/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp +++ b/contrib/llvm-project/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp @@ -4436,7 +4436,7 @@ SDValue DAGCombiner::visitREM(SDNode *N) { if (DAG.isKnownNeverZero(N1) && !TLI.isIntDivCheap(VT, Attr)) { SDValue OptimizedDiv = isSigned ? visitSDIVLike(N0, N1, N) : visitUDIVLike(N0, N1, N); - if (OptimizedDiv.getNode()) { + if (OptimizedDiv.getNode() && OptimizedDiv.getNode() != N) { // If the equivalent Div node also exists, update its users. unsigned DivOpcode = isSigned ? ISD::SDIV : ISD::UDIV; if (SDNode *DivNode = DAG.getNodeIfExists(DivOpcode, N->getVTList(), @@ -4464,6 +4464,9 @@ SDValue DAGCombiner::visitMULHS(SDNode *N) { SDLoc DL(N); if (VT.isVector()) { + if (SDValue FoldedVOp = SimplifyVBinOp(N, DL)) + return FoldedVOp; + // fold (mulhs x, 0) -> 0 // do not return N0/N1, because undef node may exist. if (ISD::isConstantSplatVectorAllZeros(N0.getNode()) || @@ -4521,6 +4524,9 @@ SDValue DAGCombiner::visitMULHU(SDNode *N) { SDLoc DL(N); if (VT.isVector()) { + if (SDValue FoldedVOp = SimplifyVBinOp(N, DL)) + return FoldedVOp; + // fold (mulhu x, 0) -> 0 // do not return N0/N1, because undef node may exist. if (ISD::isConstantSplatVectorAllZeros(N0.getNode()) || @@ -4779,6 +4785,106 @@ SDValue DAGCombiner::visitMULO(SDNode *N) { return SDValue(); } +// Function to calculate whether the Min/Max pair of SDNodes (potentially +// swapped around) make a signed saturate pattern, clamping to between -2^(BW-1) +// and 2^(BW-1)-1. Returns the node being clamped and the bitwidth of the clamp +// in BW. Should work with both SMIN/SMAX nodes and setcc/select combo. The +// operands are the same as SimplifySelectCC. N0<N1 ? N2 : N3 +static SDValue isSaturatingMinMax(SDValue N0, SDValue N1, SDValue N2, + SDValue N3, ISD::CondCode CC, unsigned &BW) { + auto isSignedMinMax = [&](SDValue N0, SDValue N1, SDValue N2, SDValue N3, + ISD::CondCode CC) { + // The compare and select operand should be the same or the select operands + // should be truncated versions of the comparison. + if (N0 != N2 && (N2.getOpcode() != ISD::TRUNCATE || N0 != N2.getOperand(0))) + return 0; + // The constants need to be the same or a truncated version of each other. + ConstantSDNode *N1C = isConstOrConstSplat(N1); + ConstantSDNode *N3C = isConstOrConstSplat(N3); + if (!N1C || !N3C) + return 0; + const APInt &C1 = N1C->getAPIntValue(); + const APInt &C2 = N3C->getAPIntValue(); + if (C1.getBitWidth() < C2.getBitWidth() || + C1 != C2.sextOrSelf(C1.getBitWidth())) + return 0; + return CC == ISD::SETLT ? ISD::SMIN : (CC == ISD::SETGT ? ISD::SMAX : 0); + }; + + // Check the initial value is a SMIN/SMAX equivalent. + unsigned Opcode0 = isSignedMinMax(N0, N1, N2, N3, CC); + if (!Opcode0) + return SDValue(); + + SDValue N00, N01, N02, N03; + ISD::CondCode N0CC; + switch (N0.getOpcode()) { + case ISD::SMIN: + case ISD::SMAX: + N00 = N02 = N0.getOperand(0); + N01 = N03 = N0.getOperand(1); + N0CC = N0.getOpcode() == ISD::SMIN ? ISD::SETLT : ISD::SETGT; + break; + case ISD::SELECT_CC: + N00 = N0.getOperand(0); + N01 = N0.getOperand(1); + N02 = N0.getOperand(2); + N03 = N0.getOperand(3); + N0CC = cast<CondCodeSDNode>(N0.getOperand(4))->get(); + break; + case ISD::SELECT: + case ISD::VSELECT: + if (N0.getOperand(0).getOpcode() != ISD::SETCC) + return SDValue(); + N00 = N0.getOperand(0).getOperand(0); + N01 = N0.getOperand(0).getOperand(1); + N02 = N0.getOperand(1); + N03 = N0.getOperand(2); + N0CC = cast<CondCodeSDNode>(N0.getOperand(0).getOperand(2))->get(); + break; + default: + return SDValue(); + } + + unsigned Opcode1 = isSignedMinMax(N00, N01, N02, N03, N0CC); + if (!Opcode1 || Opcode0 == Opcode1) + return SDValue(); + + ConstantSDNode *MinCOp = isConstOrConstSplat(Opcode0 == ISD::SMIN ? N1 : N01); + ConstantSDNode *MaxCOp = isConstOrConstSplat(Opcode0 == ISD::SMIN ? N01 : N1); + if (!MinCOp || !MaxCOp || MinCOp->getValueType(0) != MaxCOp->getValueType(0)) + return SDValue(); + + const APInt &MinC = MinCOp->getAPIntValue(); + const APInt &MaxC = MaxCOp->getAPIntValue(); + APInt MinCPlus1 = MinC + 1; + if (-MaxC != MinCPlus1 || !MinCPlus1.isPowerOf2()) + return SDValue(); + BW = MinCPlus1.exactLogBase2() + 1; + return N02; +} + +static SDValue PerformMinMaxFpToSatCombine(SDValue N0, SDValue N1, SDValue N2, + SDValue N3, ISD::CondCode CC, + SelectionDAG &DAG) { + unsigned BW; + SDValue Fp = isSaturatingMinMax(N0, N1, N2, N3, CC, BW); + if (!Fp || Fp.getOpcode() != ISD::FP_TO_SINT) + return SDValue(); + EVT FPVT = Fp.getOperand(0).getValueType(); + EVT NewVT = EVT::getIntegerVT(*DAG.getContext(), BW); + if (FPVT.isVector()) + NewVT = EVT::getVectorVT(*DAG.getContext(), NewVT, + FPVT.getVectorElementCount()); + if (!DAG.getTargetLoweringInfo().shouldConvertFpToSat( + ISD::FP_TO_SINT_SAT, Fp.getOperand(0).getValueType(), NewVT)) + return SDValue(); + SDLoc DL(Fp); + SDValue Sat = DAG.getNode(ISD::FP_TO_SINT_SAT, DL, NewVT, Fp.getOperand(0), + DAG.getValueType(NewVT.getScalarType())); + return DAG.getSExtOrTrunc(Sat, DL, N2->getValueType(0)); +} + SDValue DAGCombiner::visitIMINMAX(SDNode *N) { SDValue N0 = N->getOperand(0); SDValue N1 = N->getOperand(1); @@ -4817,6 +4923,11 @@ SDValue DAGCombiner::visitIMINMAX(SDNode *N) { return DAG.getNode(AltOpcode, DL, VT, N0, N1); } + if (Opcode == ISD::SMIN || Opcode == ISD::SMAX) + if (SDValue S = PerformMinMaxFpToSatCombine( + N0, N1, N0, N1, Opcode == ISD::SMIN ? ISD::SETLT : ISD::SETGT, DAG)) + return S; + // Simplify the operands using demanded-bits information. if (SimplifyDemandedBits(SDValue(N, 0))) return SDValue(N, 0); @@ -9940,9 +10051,8 @@ SDValue DAGCombiner::visitMSTORE(SDNode *N) { // If this is a masked load with an all ones mask, we can use a unmasked load. // FIXME: Can we do this for indexed, compressing, or truncating stores? - if (ISD::isConstantSplatVectorAllOnes(Mask.getNode()) && - MST->isUnindexed() && !MST->isCompressingStore() && - !MST->isTruncatingStore()) + if (ISD::isConstantSplatVectorAllOnes(Mask.getNode()) && MST->isUnindexed() && + !MST->isCompressingStore() && !MST->isTruncatingStore()) return DAG.getStore(MST->getChain(), SDLoc(N), MST->getValue(), MST->getBasePtr(), MST->getMemOperand()); @@ -9997,9 +10107,8 @@ SDValue DAGCombiner::visitMLOAD(SDNode *N) { // If this is a masked load with an all ones mask, we can use a unmasked load. // FIXME: Can we do this for indexed, expanding, or extending loads? - if (ISD::isConstantSplatVectorAllOnes(Mask.getNode()) && - MLD->isUnindexed() && !MLD->isExpandingLoad() && - MLD->getExtensionType() == ISD::NON_EXTLOAD) { + if (ISD::isConstantSplatVectorAllOnes(Mask.getNode()) && MLD->isUnindexed() && + !MLD->isExpandingLoad() && MLD->getExtensionType() == ISD::NON_EXTLOAD) { SDValue NewLd = DAG.getLoad(N->getValueType(0), SDLoc(N), MLD->getChain(), MLD->getBasePtr(), MLD->getMemOperand()); return CombineTo(N, NewLd, NewLd.getValue(1)); @@ -10138,6 +10247,9 @@ SDValue DAGCombiner::visitVSELECT(SDNode *N) { return FMinMax; } + if (SDValue S = PerformMinMaxFpToSatCombine(LHS, RHS, N1, N2, CC, DAG)) + return S; + // If this select has a condition (setcc) with narrower operands than the // select, try to widen the compare to match the select width. // TODO: This should be extended to handle any constant. @@ -15007,7 +15119,7 @@ SDValue DAGCombiner::visitFP_EXTEND(SDNode *N) { // fold (fpext (load x)) -> (fpext (fptrunc (extload x))) if (ISD::isNormalLoad(N0.getNode()) && N0.hasOneUse() && - TLI.isLoadExtLegal(ISD::EXTLOAD, VT, N0.getValueType())) { + TLI.isLoadExtLegalOrCustom(ISD::EXTLOAD, VT, N0.getValueType())) { LoadSDNode *LN0 = cast<LoadSDNode>(N0); SDValue ExtLoad = DAG.getExtLoad(ISD::EXTLOAD, SDLoc(N), VT, LN0->getChain(), @@ -23034,6 +23146,9 @@ SDValue DAGCombiner::SimplifySelectCC(const SDLoc &DL, SDValue N0, SDValue N1, DAG.getSExtOrTrunc(CC == ISD::SETLT ? N3 : N2, DL, VT)); } + if (SDValue S = PerformMinMaxFpToSatCombine(N0, N1, N2, N3, CC, DAG)) + return S; + return SDValue(); } |
