diff options
| author | Dimitry Andric <dim@FreeBSD.org> | 2018-08-02 17:32:43 +0000 | 
|---|---|---|
| committer | Dimitry Andric <dim@FreeBSD.org> | 2018-08-02 17:32:43 +0000 | 
| commit | b7eb8e35e481a74962664b63dfb09483b200209a (patch) | |
| tree | 1937fb4a348458ce2d02ade03ac3bb0aa18d2fcd /lib/CodeGen/SelectionDAG/DAGCombiner.cpp | |
| parent | eb11fae6d08f479c0799db45860a98af528fa6e7 (diff) | |
Notes
Diffstat (limited to 'lib/CodeGen/SelectionDAG/DAGCombiner.cpp')
| -rw-r--r-- | lib/CodeGen/SelectionDAG/DAGCombiner.cpp | 231 | 
1 files changed, 193 insertions, 38 deletions
| diff --git a/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/lib/CodeGen/SelectionDAG/DAGCombiner.cpp index 7a99687757f8..a8c4b85df321 100644 --- a/lib/CodeGen/SelectionDAG/DAGCombiner.cpp +++ b/lib/CodeGen/SelectionDAG/DAGCombiner.cpp @@ -72,7 +72,6 @@  #include <string>  #include <tuple>  #include <utility> -#include <vector>  using namespace llvm; @@ -483,9 +482,6 @@ namespace {      /// returns false.      bool findBetterNeighborChains(StoreSDNode *St); -    /// Match "(X shl/srl V1) & V2" where V2 may not be present. -    bool MatchRotateHalf(SDValue Op, SDValue &Shift, SDValue &Mask); -      /// Holds a pointer to an LSBaseSDNode as well as information on where it      /// is located in a sequence of memory operations connected by a chain.      struct MemOpLink { @@ -2671,6 +2667,12 @@ SDValue DAGCombiner::visitSUB(SDNode *N) {      return DAG.getNode(ISD::SUB, DL, VT, N0.getOperand(0),                         N0.getOperand(1).getOperand(0)); +  // fold (A-(B-C)) -> A+(C-B) +  if (N1.getOpcode() == ISD::SUB && N1.hasOneUse()) +    return DAG.getNode(ISD::ADD, DL, VT, N0, +                       DAG.getNode(ISD::SUB, DL, VT, N1.getOperand(1), +                                   N1.getOperand(0))); +    // fold (X - (-Y * Z)) -> (X + (Y * Z))    if (N1.getOpcode() == ISD::MUL && N1.hasOneUse()) {      if (N1.getOperand(0).getOpcode() == ISD::SUB && @@ -2740,6 +2742,17 @@ SDValue DAGCombiner::visitSUB(SDNode *N) {      }    } +  // Prefer an add for more folding potential and possibly better codegen: +  // sub N0, (lshr N10, width-1) --> add N0, (ashr N10, width-1) +  if (!LegalOperations && N1.getOpcode() == ISD::SRL && N1.hasOneUse()) { +    SDValue ShAmt = N1.getOperand(1); +    ConstantSDNode *ShAmtC = isConstOrConstSplat(ShAmt); +    if (ShAmtC && ShAmtC->getZExtValue() == N1.getScalarValueSizeInBits() - 1) { +      SDValue SRA = DAG.getNode(ISD::SRA, DL, VT, N1.getOperand(0), ShAmt); +      return DAG.getNode(ISD::ADD, DL, VT, N0, SRA); +    } +  } +    
return SDValue();  } @@ -4205,8 +4218,8 @@ bool DAGCombiner::SearchForAndLoads(SDNode *N,      // Allow one node which will masked along with any loads found.      if (NodeToMask)        return false; -  -    // Also ensure that the node to be masked only produces one data result.  + +    // Also ensure that the node to be masked only produces one data result.      NodeToMask = Op.getNode();      if (NodeToMask->getNumValues() > 1) {        bool HasValue = false; @@ -5148,25 +5161,140 @@ SDValue DAGCombiner::visitOR(SDNode *N) {    return SDValue();  } -/// Match "(X shl/srl V1) & V2" where V2 may not be present. -bool DAGCombiner::MatchRotateHalf(SDValue Op, SDValue &Shift, SDValue &Mask) { -  if (Op.getOpcode() == ISD::AND) { -    if (DAG.isConstantIntBuildVectorOrConstantInt(Op.getOperand(1))) { -      Mask = Op.getOperand(1); -      Op = Op.getOperand(0); -    } else { -      return false; -    } +static SDValue stripConstantMask(SelectionDAG &DAG, SDValue Op, SDValue &Mask) { +  if (Op.getOpcode() == ISD::AND && +      DAG.isConstantIntBuildVectorOrConstantInt(Op.getOperand(1))) { +    Mask = Op.getOperand(1); +    return Op.getOperand(0);    } +  return Op; +} +/// Match "(X shl/srl V1) & V2" where V2 may not be present. +static bool matchRotateHalf(SelectionDAG &DAG, SDValue Op, SDValue &Shift, +                            SDValue &Mask) { +  Op = stripConstantMask(DAG, Op, Mask);    if (Op.getOpcode() == ISD::SRL || Op.getOpcode() == ISD::SHL) {      Shift = Op;      return true;    } -    return false;  } +/// Helper function for visitOR to extract the needed side of a rotate idiom +/// from a shl/srl/mul/udiv.  This is meant to handle cases where +/// InstCombine merged some outside op with one of the shifts from +/// the rotate pattern. +/// \returns An empty \c SDValue if the needed shift couldn't be extracted. 
+/// Otherwise, returns an expansion of \p ExtractFrom based on the following +/// patterns: +/// +///   (or (mul v c0) (shrl (mul v c1) c2)): +///     expands (mul v c0) -> (shl (mul v c1) c3) +/// +///   (or (udiv v c0) (shl (udiv v c1) c2)): +///     expands (udiv v c0) -> (shrl (udiv v c1) c3) +/// +///   (or (shl v c0) (shrl (shl v c1) c2)): +///     expands (shl v c0) -> (shl (shl v c1) c3) +/// +///   (or (shrl v c0) (shl (shrl v c1) c2)): +///     expands (shrl v c0) -> (shrl (shrl v c1) c3) +/// +/// Such that in all cases, c3+c2==bitwidth(op v c1). +static SDValue extractShiftForRotate(SelectionDAG &DAG, SDValue OppShift, +                                     SDValue ExtractFrom, SDValue &Mask, +                                     const SDLoc &DL) { +  assert(OppShift && ExtractFrom && "Empty SDValue"); +  assert( +      (OppShift.getOpcode() == ISD::SHL || OppShift.getOpcode() == ISD::SRL) && +      "Existing shift must be valid as a rotate half"); + +  ExtractFrom = stripConstantMask(DAG, ExtractFrom, Mask); +  // Preconditions: +  //    (or (op0 v c0) (shiftl/r (op0 v c1) c2)) +  // +  // Find opcode of the needed shift to be extracted from (op0 v c0). +  unsigned Opcode = ISD::DELETED_NODE; +  bool IsMulOrDiv = false; +  // Set Opcode and IsMulOrDiv if the extract opcode matches the needed shift +  // opcode or its arithmetic (mul or udiv) variant. +  auto SelectOpcode = [&](unsigned NeededShift, unsigned MulOrDivVariant) { +    IsMulOrDiv = ExtractFrom.getOpcode() == MulOrDivVariant; +    if (!IsMulOrDiv && ExtractFrom.getOpcode() != NeededShift) +      return false; +    Opcode = NeededShift; +    return true; +  }; +  // op0 must be either the needed shift opcode or the mul/udiv equivalent +  // that the needed shift can be extracted from. 
+  if ((OppShift.getOpcode() != ISD::SRL || !SelectOpcode(ISD::SHL, ISD::MUL)) && +      (OppShift.getOpcode() != ISD::SHL || !SelectOpcode(ISD::SRL, ISD::UDIV))) +    return SDValue(); + +  // op0 must be the same opcode on both sides, have the same LHS argument, +  // and produce the same value type. +  SDValue OppShiftLHS = OppShift.getOperand(0); +  EVT ShiftedVT = OppShiftLHS.getValueType(); +  if (OppShiftLHS.getOpcode() != ExtractFrom.getOpcode() || +      OppShiftLHS.getOperand(0) != ExtractFrom.getOperand(0) || +      ShiftedVT != ExtractFrom.getValueType()) +    return SDValue(); + +  // Amount of the existing shift. +  ConstantSDNode *OppShiftCst = isConstOrConstSplat(OppShift.getOperand(1)); +  // Constant mul/udiv/shift amount from the RHS of the shift's LHS op. +  ConstantSDNode *OppLHSCst = isConstOrConstSplat(OppShiftLHS.getOperand(1)); +  // Constant mul/udiv/shift amount from the RHS of the ExtractFrom op. +  ConstantSDNode *ExtractFromCst = +      isConstOrConstSplat(ExtractFrom.getOperand(1)); +  // TODO: We should be able to handle non-uniform constant vectors for these values +  // Check that we have constant values. +  if (!OppShiftCst || !OppShiftCst->getAPIntValue() || +      !OppLHSCst || !OppLHSCst->getAPIntValue() || +      !ExtractFromCst || !ExtractFromCst->getAPIntValue()) +    return SDValue(); + +  // Compute the shift amount we need to extract to complete the rotate. +  const unsigned VTWidth = ShiftedVT.getScalarSizeInBits(); +  APInt NeededShiftAmt = VTWidth - OppShiftCst->getAPIntValue(); +  if (NeededShiftAmt.isNegative()) +    return SDValue(); +  // Normalize the bitwidth of the two mul/udiv/shift constant operands. +  APInt ExtractFromAmt = ExtractFromCst->getAPIntValue(); +  APInt OppLHSAmt = OppLHSCst->getAPIntValue(); +  zeroExtendToMatch(ExtractFromAmt, OppLHSAmt); + +  // Now try extract the needed shift from the ExtractFrom op and see if the +  // result matches up with the existing shift's LHS op. 
+  if (IsMulOrDiv) { +    // Op to extract from is a mul or udiv by a constant. +    // Check: +    //     c0 / (1 << (bitwidth(op0 v c0) - c2)) == c1 +    //     c0 % (1 << (bitwidth(op0 v c0) - c2)) == 0 +    const APInt ExtractDiv = APInt::getOneBitSet(ExtractFromAmt.getBitWidth(), +                                                 NeededShiftAmt.getZExtValue()); +    APInt ResultAmt; +    APInt Rem; +    APInt::udivrem(ExtractFromAmt, ExtractDiv, ResultAmt, Rem); +    if (Rem != 0 || ResultAmt != OppLHSAmt) +      return SDValue(); +  } else { +    // Op to extract from is a shift by a constant. +    // Check: +    //      c0 - (bitwidth(op0 v c0) - c2) == c1 +    if (OppLHSAmt != ExtractFromAmt - NeededShiftAmt.zextOrTrunc( +                                          ExtractFromAmt.getBitWidth())) +      return SDValue(); +  } + +  // Return the expanded shift op that should allow a rotate to be formed. +  EVT ShiftVT = OppShift.getOperand(1).getValueType(); +  EVT ResVT = ExtractFrom.getValueType(); +  SDValue NewShiftNode = DAG.getConstant(NeededShiftAmt, DL, ShiftVT); +  return DAG.getNode(Opcode, DL, ResVT, OppShiftLHS, NewShiftNode); +} +  // Return true if we can prove that, whenever Neg and Pos are both in the  // range [0, EltSize), Neg == (Pos == 0 ? 0 : EltSize - Pos).  This means that  // for two opposing shifts shift1 and shift2 and a value X with OpBits bits: @@ -5333,13 +5461,40 @@ SDNode *DAGCombiner::MatchRotate(SDValue LHS, SDValue RHS, const SDLoc &DL) {    // Match "(X shl/srl V1) & V2" where V2 may not be present.    SDValue LHSShift;   // The shift.    SDValue LHSMask;    // AND value if any. -  if (!MatchRotateHalf(LHS, LHSShift, LHSMask)) -    return nullptr; // Not part of a rotate. +  matchRotateHalf(DAG, LHS, LHSShift, LHSMask);    SDValue RHSShift;   // The shift.    SDValue RHSMask;    // AND value if any. -  if (!MatchRotateHalf(RHS, RHSShift, RHSMask)) -    return nullptr; // Not part of a rotate. 
+  matchRotateHalf(DAG, RHS, RHSShift, RHSMask); + +  // If neither side matched a rotate half, bail +  if (!LHSShift && !RHSShift) +    return nullptr; + +  // InstCombine may have combined a constant shl, srl, mul, or udiv with one +  // side of the rotate, so try to handle that here. In all cases we need to +  // pass the matched shift from the opposite side to compute the opcode and +  // needed shift amount to extract.  We still want to do this if both sides +  // matched a rotate half because one half may be a potential overshift that +  // can be broken down (ie if InstCombine merged two shl or srl ops into a +  // single one). + +  // Have LHS side of the rotate, try to extract the needed shift from the RHS. +  if (LHSShift) +    if (SDValue NewRHSShift = +            extractShiftForRotate(DAG, LHSShift, RHS, RHSMask, DL)) +      RHSShift = NewRHSShift; +  // Have RHS side of the rotate, try to extract the needed shift from the LHS. +  if (RHSShift) +    if (SDValue NewLHSShift = +            extractShiftForRotate(DAG, RHSShift, LHS, LHSMask, DL)) +      LHSShift = NewLHSShift; + +  // If a side is still missing, nothing else we can do. +  if (!RHSShift || !LHSShift) +    return nullptr; + +  // At this point we've matched or extracted a shift op on each side.    if (LHSShift.getOperand(0) != RHSShift.getOperand(0))      return nullptr;   // Not shifting the same value. 
@@ -10270,7 +10425,7 @@ SDValue DAGCombiner::visitFSUBForFMACombine(SDNode *N) {                                                   N10.getOperand(0))),                           DAG.getNode(ISD::FP_EXTEND, SL, VT,                                       N10.getOperand(1)), -                         N0, Flags);           +                         N0, Flags);      }    } @@ -10333,7 +10488,7 @@ SDValue DAGCombiner::visitFSUBForFMACombine(SDNode *N) {                                       N0.getOperand(2).getOperand(0),                                       N0.getOperand(2).getOperand(1),                                       DAG.getNode(ISD::FNEG, SL, VT, -                                                 N1), Flags), Flags);           +                                                 N1), Flags), Flags);      }      // fold (fsub x, (fma y, z, (fmul u, v))) @@ -10348,7 +10503,7 @@ SDValue DAGCombiner::visitFSUBForFMACombine(SDNode *N) {                           N1.getOperand(1),                           DAG.getNode(PreferredFusedOpcode, SL, VT,                                       DAG.getNode(ISD::FNEG, SL, VT, N20), -                                     N21, N0, Flags), Flags);       +                                     N21, N0, Flags), Flags);      } @@ -10368,7 +10523,7 @@ SDValue DAGCombiner::visitFSUBForFMACombine(SDNode *N) {                                           DAG.getNode(ISD::FP_EXTEND, SL, VT,                                                       N020.getOperand(1)),                                           DAG.getNode(ISD::FNEG, SL, VT, -                                                     N1), Flags), Flags);               +                                                     N1), Flags), Flags);          }        }      } @@ -10396,7 +10551,7 @@ SDValue DAGCombiner::visitFSUBForFMACombine(SDNode *N) {                                           DAG.getNode(ISD::FP_EXTEND, SL, VT,                                                       
N002.getOperand(1)),                                           DAG.getNode(ISD::FNEG, SL, VT, -                                                     N1), Flags), Flags);               +                                                     N1), Flags), Flags);          }        }      } @@ -10419,7 +10574,7 @@ SDValue DAGCombiner::visitFSUBForFMACombine(SDNode *N) {                                                                 VT, N1200)),                                         DAG.getNode(ISD::FP_EXTEND, SL, VT,                                                     N1201), -                                       N0, Flags), Flags);         +                                       N0, Flags), Flags);        }      } @@ -10450,7 +10605,7 @@ SDValue DAGCombiner::visitFSUBForFMACombine(SDNode *N) {                                                                 VT, N1020)),                                         DAG.getNode(ISD::FP_EXTEND, SL, VT,                                                     N1021), -                                       N0, Flags), Flags);         +                                       N0, Flags), Flags);        }      }    } @@ -10506,7 +10661,7 @@ SDValue DAGCombiner::visitFMULForFMADistributiveCombine(SDNode *N) {                             Y, Flags);        if (XC1 && XC1->isExactlyValue(-1.0))          return DAG.getNode(PreferredFusedOpcode, SL, VT, X.getOperand(0), Y, -                           DAG.getNode(ISD::FNEG, SL, VT, Y), Flags);       +                           DAG.getNode(ISD::FNEG, SL, VT, Y), Flags);      }      return SDValue();    }; @@ -10530,7 +10685,7 @@ SDValue DAGCombiner::visitFMULForFMADistributiveCombine(SDNode *N) {        if (XC0 && XC0->isExactlyValue(-1.0))          return DAG.getNode(PreferredFusedOpcode, SL, VT,                             DAG.getNode(ISD::FNEG, SL, VT, X.getOperand(1)), Y, -                           DAG.getNode(ISD::FNEG, SL, VT, Y), Flags);       +                           
DAG.getNode(ISD::FNEG, SL, VT, Y), Flags);        auto XC1 = isConstOrConstSplatFP(X.getOperand(1));        if (XC1 && XC1->isExactlyValue(+1.0)) @@ -10838,12 +10993,12 @@ SDValue DAGCombiner::visitFMUL(SDNode *N) {    if (SDValue NewSel = foldBinOpIntoSelect(N))      return NewSel; -  if (Options.UnsafeFPMath ||  +  if (Options.UnsafeFPMath ||        (Flags.hasNoNaNs() && Flags.hasNoSignedZeros())) {      // fold (fmul A, 0) -> 0      if (N1CFP && N1CFP->isZero())        return N1; -  }  +  }    if (Options.UnsafeFPMath || Flags.hasAllowReassociation()) {      // fmul (fmul X, C1), C2 -> fmul X, C1 * C2 @@ -11258,7 +11413,7 @@ SDValue DAGCombiner::visitFREM(SDNode *N) {  SDValue DAGCombiner::visitFSQRT(SDNode *N) {    SDNodeFlags Flags = N->getFlags(); -  if (!DAG.getTarget().Options.UnsafeFPMath &&  +  if (!DAG.getTarget().Options.UnsafeFPMath &&        !Flags.hasApproximateFuncs())      return SDValue(); @@ -17913,9 +18068,9 @@ SDValue DAGCombiner::BuildSDIV(SDNode *N) {    if (C->isNullValue())      return SDValue(); -  std::vector<SDNode *> Built; +  SmallVector<SDNode *, 8> Built;    SDValue S = -      TLI.BuildSDIV(N, C->getAPIntValue(), DAG, LegalOperations, &Built); +      TLI.BuildSDIV(N, C->getAPIntValue(), DAG, LegalOperations, Built);    for (SDNode *N : Built)      AddToWorklist(N); @@ -17933,8 +18088,8 @@ SDValue DAGCombiner::BuildSDIVPow2(SDNode *N) {    if (C->isNullValue())      return SDValue(); -  std::vector<SDNode *> Built; -  SDValue S = TLI.BuildSDIVPow2(N, C->getAPIntValue(), DAG, &Built); +  SmallVector<SDNode *, 8> Built; +  SDValue S = TLI.BuildSDIVPow2(N, C->getAPIntValue(), DAG, Built);    for (SDNode *N : Built)      AddToWorklist(N); @@ -17959,9 +18114,9 @@ SDValue DAGCombiner::BuildUDIV(SDNode *N) {    if (C->isNullValue())      return SDValue(); -  std::vector<SDNode *> Built; +  SmallVector<SDNode *, 8> Built;    SDValue S = -      TLI.BuildUDIV(N, C->getAPIntValue(), DAG, LegalOperations, &Built); +      TLI.BuildUDIV(N, 
C->getAPIntValue(), DAG, LegalOperations, Built);    for (SDNode *N : Built)      AddToWorklist(N); | 
