diff options
Diffstat (limited to 'lib/CodeGen/SelectionDAG/DAGCombiner.cpp')
| -rw-r--r-- | lib/CodeGen/SelectionDAG/DAGCombiner.cpp | 3143 |
1 files changed, 1927 insertions, 1216 deletions
diff --git a/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/lib/CodeGen/SelectionDAG/DAGCombiner.cpp index a8c4b85df321..ff5505c97721 100644 --- a/lib/CodeGen/SelectionDAG/DAGCombiner.cpp +++ b/lib/CodeGen/SelectionDAG/DAGCombiner.cpp @@ -20,6 +20,7 @@ #include "llvm/ADT/APInt.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/IntervalMap.h" #include "llvm/ADT/None.h" #include "llvm/ADT/Optional.h" #include "llvm/ADT/STLExtras.h" @@ -83,6 +84,7 @@ STATISTIC(PostIndexedNodes, "Number of post-indexed nodes created"); STATISTIC(OpsNarrowed , "Number of load/op/store narrowed"); STATISTIC(LdStFP2Int , "Number of fp load/store pairs transformed to int"); STATISTIC(SlicedLoads, "Number of load sliced"); +STATISTIC(NumFPLogicOpsConv, "Number of logic ops converted to fp ops"); static cl::opt<bool> CombinerGlobalAA("combiner-global-alias-analysis", cl::Hidden, @@ -249,6 +251,11 @@ namespace { SDValue SplitIndexingFromLoad(LoadSDNode *LD); bool SliceUpLoad(SDNode *N); + // Scalars have size 0 to distinguish from singleton vectors. + SDValue ForwardStoreValueToDirectLoad(LoadSDNode *LD); + bool getTruncatedStoreValue(StoreSDNode *ST, SDValue &Val); + bool extendLoadedValueToExtension(LoadSDNode *LD, SDValue &Val); + /// Replace an ISD::EXTRACT_VECTOR_ELT of a load with a narrowed /// load. /// @@ -257,8 +264,9 @@ namespace { /// \param EltNo index of the vector element to load. /// \param OriginalLoad load that EVE came from to be replaced. /// \returns EVE on success SDValue() on failure. - SDValue ReplaceExtractVectorEltOfLoadWithNarrowedLoad( - SDNode *EVE, EVT InVecVT, SDValue EltNo, LoadSDNode *OriginalLoad); + SDValue scalarizeExtractedVectorLoad(SDNode *EVE, EVT InVecVT, + SDValue EltNo, + LoadSDNode *OriginalLoad); void ReplaceLoadWithPromotedLoad(SDNode *Load, SDNode *ExtLoad); SDValue PromoteOperand(SDValue Op, EVT PVT, bool &Replace); SDValue SExtPromoteOperand(SDValue Op, EVT PVT); @@ -285,6 +293,8 @@ namespace { SDValue visitADD(SDNode *N); SDValue visitADDLike(SDValue N0, SDValue N1, SDNode *LocReference); SDValue visitSUB(SDNode *N); + SDValue visitADDSAT(SDNode *N); + SDValue visitSUBSAT(SDNode *N); SDValue visitADDC(SDNode *N); SDValue visitUADDO(SDNode *N); SDValue visitUADDOLike(SDValue N0, SDValue N1, SDNode *N); @@ -318,6 +328,7 @@ namespace { SDValue visitSHL(SDNode *N); SDValue visitSRA(SDNode *N); SDValue visitSRL(SDNode *N); + SDValue visitFunnelShift(SDNode *N); SDValue visitRotate(SDNode *N); SDValue visitABS(SDNode *N); SDValue visitBSWAP(SDNode *N); @@ -350,6 +361,7 @@ namespace { SDValue visitFREM(SDNode *N); SDValue visitFSQRT(SDNode *N); SDValue visitFCOPYSIGN(SDNode *N); + SDValue visitFPOW(SDNode *N); SDValue visitSINT_TO_FP(SDNode *N); SDValue visitUINT_TO_FP(SDNode *N); SDValue visitFP_TO_SINT(SDNode *N); @@ -364,6 +376,8 @@ namespace { SDValue visitFFLOOR(SDNode *N); SDValue visitFMINNUM(SDNode *N); SDValue visitFMAXNUM(SDNode *N); + SDValue visitFMINIMUM(SDNode *N); + SDValue visitFMAXIMUM(SDNode *N); SDValue visitBRCOND(SDNode *N); SDValue visitBR_CC(SDNode *N); SDValue visitLOAD(SDNode *N); @@ -393,7 +407,7 @@ namespace { SDValue XformToShuffleWithZero(SDNode *N); SDValue ReassociateOps(unsigned Opc, const SDLoc &DL, SDValue N0, - SDValue N1); + SDValue N1, SDNodeFlags Flags); SDValue visitShiftByConstant(SDNode *N, ConstantSDNode *Amt); @@ -401,11 +415,14 @@ namespace { SDValue foldVSelectOfConstants(SDNode *N); SDValue foldBinOpIntoSelect(SDNode *BO); bool SimplifySelectOps(SDNode *SELECT, SDValue LHS, SDValue RHS); - SDValue SimplifyBinOpWithSameOpcodeHands(SDNode *N); + SDValue hoistLogicOpWithSameOpcodeHands(SDNode *N); SDValue SimplifySelect(const SDLoc &DL, SDValue N0, SDValue N1, SDValue N2); SDValue SimplifySelectCC(const SDLoc &DL, SDValue N0, SDValue N1, SDValue N2, SDValue N3, ISD::CondCode CC, bool NotExtCompare = false); + SDValue convertSelectOfFPConstantsToLoadOffset( + const SDLoc &DL, SDValue N0, SDValue N1, SDValue N2, SDValue N3, + ISD::CondCode CC); SDValue foldSelectCCToShiftAnd(const SDLoc &DL, SDValue N0, SDValue N1, SDValue N2, SDValue N3, ISD::CondCode CC); SDValue foldLogicOfSetCCs(bool IsAnd, SDValue N0, SDValue N1, @@ -455,7 +472,6 @@ namespace { SDValue TransformFPLoadStorePair(SDNode *N); SDValue convertBuildVecZextToZext(SDNode *N); SDValue reduceBuildVecExtToExtBuildVec(SDNode *N); - SDValue reduceBuildVecConvertToConvertBuildVec(SDNode *N); SDValue reduceBuildVecToShuffle(SDNode *N); SDValue createBuildVecShuffle(const SDLoc &DL, SDNode *N, ArrayRef<int> VectorMask, SDValue VecIn1, @@ -482,6 +498,10 @@ namespace { /// returns false. bool findBetterNeighborChains(StoreSDNode *St); + // Helper for findBetterNeighborChains. Walk up store chain add additional + // chained stores that do not overlap and can be parallelized. + bool parallelizeChainedStores(StoreSDNode *St); + /// Holds a pointer to an LSBaseSDNode as well as information on where it /// is located in a sequence of memory operations connected by a chain. struct MemOpLink { @@ -515,7 +535,7 @@ namespace { EVT &MemVT, unsigned ShAmt = 0); /// Used by BackwardsPropagateMask to find suitable loads. - bool SearchForAndLoads(SDNode *N, SmallPtrSetImpl<LoadSDNode*> &Loads, + bool SearchForAndLoads(SDNode *N, SmallVectorImpl<LoadSDNode*> &Loads, SmallPtrSetImpl<SDNode*> &NodesWithConsts, ConstantSDNode *Mask, SDNode *&NodeToMask); /// Attempt to propagate a given AND node back to load leaves so that they @@ -865,12 +885,6 @@ bool DAGCombiner::isOneUseSetCC(SDValue N) const { return false; } -static SDValue peekThroughBitcast(SDValue V) { - while (V.getOpcode() == ISD::BITCAST) - V = V.getOperand(0); - return V; -} - // Returns the SDNode if it is a constant float BuildVector // or constant float. static SDNode *isConstantFPBuildVectorOrConstantFP(SDValue N) { @@ -901,50 +915,23 @@ static bool isConstantOrConstantVector(SDValue N, bool NoOpaques = false) { return true; } -// Determines if it is a constant null integer or a splatted vector of a -// constant null integer (with no undefs). -// Build vector implicit truncation is not an issue for null values. -static bool isNullConstantOrNullSplatConstant(SDValue N) { - // TODO: may want to use peekThroughBitcast() here. - if (ConstantSDNode *Splat = isConstOrConstSplat(N)) - return Splat->isNullValue(); - return false; -} - -// Determines if it is a constant integer of one or a splatted vector of a -// constant integer of one (with no undefs). -// Do not permit build vector implicit truncation. -static bool isOneConstantOrOneSplatConstant(SDValue N) { - // TODO: may want to use peekThroughBitcast() here. - unsigned BitWidth = N.getScalarValueSizeInBits(); - if (ConstantSDNode *Splat = isConstOrConstSplat(N)) - return Splat->isOne() && Splat->getAPIntValue().getBitWidth() == BitWidth; - return false; -} - -// Determines if it is a constant integer of all ones or a splatted vector of a -// constant integer of all ones (with no undefs). -// Do not permit build vector implicit truncation. -static bool isAllOnesConstantOrAllOnesSplatConstant(SDValue N) { - N = peekThroughBitcast(N); - unsigned BitWidth = N.getScalarValueSizeInBits(); - if (ConstantSDNode *Splat = isConstOrConstSplat(N)) - return Splat->isAllOnesValue() && - Splat->getAPIntValue().getBitWidth() == BitWidth; - return false; -} - // Determines if a BUILD_VECTOR is composed of all-constants possibly mixed with // undef's. -static bool isAnyConstantBuildVector(const SDNode *N) { - return ISD::isBuildVectorOfConstantSDNodes(N) || - ISD::isBuildVectorOfConstantFPSDNodes(N); +static bool isAnyConstantBuildVector(SDValue V, bool NoOpaques = false) { + if (V.getOpcode() != ISD::BUILD_VECTOR) + return false; + return isConstantOrConstantVector(V, NoOpaques) || + ISD::isBuildVectorOfConstantFPSDNodes(V.getNode()); } SDValue DAGCombiner::ReassociateOps(unsigned Opc, const SDLoc &DL, SDValue N0, - SDValue N1) { + SDValue N1, SDNodeFlags Flags) { + // Don't reassociate reductions. + if (Flags.hasVectorReduction()) + return SDValue(); + EVT VT = N0.getValueType(); - if (N0.getOpcode() == Opc) { + if (N0.getOpcode() == Opc && !N0->getFlags().hasVectorReduction()) { if (SDNode *L = DAG.isConstantIntBuildVectorOrConstantInt(N0.getOperand(1))) { if (SDNode *R = DAG.isConstantIntBuildVectorOrConstantInt(N1)) { // reassoc. (op (op x, c1), c2) -> (op x, (op c1, c2)) @@ -964,7 +951,7 @@ SDValue DAGCombiner::ReassociateOps(unsigned Opc, const SDLoc &DL, SDValue N0, } } - if (N1.getOpcode() == Opc) { + if (N1.getOpcode() == Opc && !N1->getFlags().hasVectorReduction()) { if (SDNode *R = DAG.isConstantIntBuildVectorOrConstantInt(N1.getOperand(1))) { if (SDNode *L = DAG.isConstantIntBuildVectorOrConstantInt(N0)) { // reassoc. (op c2, (op x, c1)) -> (op x, (op c1, c2)) @@ -1501,6 +1488,10 @@ SDValue DAGCombiner::visit(SDNode *N) { case ISD::MERGE_VALUES: return visitMERGE_VALUES(N); case ISD::ADD: return visitADD(N); case ISD::SUB: return visitSUB(N); + case ISD::SADDSAT: + case ISD::UADDSAT: return visitADDSAT(N); + case ISD::SSUBSAT: + case ISD::USUBSAT: return visitSUBSAT(N); case ISD::ADDC: return visitADDC(N); case ISD::UADDO: return visitUADDO(N); case ISD::SUBC: return visitSUBC(N); @@ -1532,6 +1523,8 @@ SDValue DAGCombiner::visit(SDNode *N) { case ISD::SRL: return visitSRL(N); case ISD::ROTR: case ISD::ROTL: return visitRotate(N); + case ISD::FSHL: + case ISD::FSHR: return visitFunnelShift(N); case ISD::ABS: return visitABS(N); case ISD::BSWAP: return visitBSWAP(N); case ISD::BITREVERSE: return visitBITREVERSE(N); @@ -1564,6 +1557,7 @@ SDValue DAGCombiner::visit(SDNode *N) { case ISD::FREM: return visitFREM(N); case ISD::FSQRT: return visitFSQRT(N); case ISD::FCOPYSIGN: return visitFCOPYSIGN(N); + case ISD::FPOW: return visitFPOW(N); case ISD::SINT_TO_FP: return visitSINT_TO_FP(N); case ISD::UINT_TO_FP: return visitUINT_TO_FP(N); case ISD::FP_TO_SINT: return visitFP_TO_SINT(N); @@ -1576,6 +1570,8 @@ SDValue DAGCombiner::visit(SDNode *N) { case ISD::FFLOOR: return visitFFLOOR(N); case ISD::FMINNUM: return visitFMINNUM(N); case ISD::FMAXNUM: return visitFMAXNUM(N); + case ISD::FMINIMUM: return visitFMINIMUM(N); + case ISD::FMAXIMUM: return visitFMAXIMUM(N); case ISD::FCEIL: return visitFCEIL(N); case ISD::FTRUNC: return visitFTRUNC(N); case ISD::BRCOND: return visitBRCOND(N); @@ -1855,8 +1851,11 @@ SDValue DAGCombiner::visitMERGE_VALUES(SDNode *N) { // can be tried again once they have new operands. AddUsersToWorklist(N); do { + // Do as a single replacement to avoid rewalking use lists. + SmallVector<SDValue, 8> Ops; for (unsigned i = 0, e = N->getNumOperands(); i != e; ++i) - DAG.ReplaceAllUsesOfValueWith(SDValue(N, i), N->getOperand(i)); + Ops.push_back(N->getOperand(i)); + DAG.ReplaceAllUsesWith(N, Ops.data()); } while (!N->use_empty()); deleteAndRecombine(N); return SDValue(N, 0); // Return N so it doesn't get rechecked! @@ -1870,17 +1869,7 @@ static ConstantSDNode *getAsNonOpaqueConstant(SDValue N) { } SDValue DAGCombiner::foldBinOpIntoSelect(SDNode *BO) { - auto BinOpcode = BO->getOpcode(); - assert((BinOpcode == ISD::ADD || BinOpcode == ISD::SUB || - BinOpcode == ISD::MUL || BinOpcode == ISD::SDIV || - BinOpcode == ISD::UDIV || BinOpcode == ISD::SREM || - BinOpcode == ISD::UREM || BinOpcode == ISD::AND || - BinOpcode == ISD::OR || BinOpcode == ISD::XOR || - BinOpcode == ISD::SHL || BinOpcode == ISD::SRL || - BinOpcode == ISD::SRA || BinOpcode == ISD::FADD || - BinOpcode == ISD::FSUB || BinOpcode == ISD::FMUL || - BinOpcode == ISD::FDIV || BinOpcode == ISD::FREM) && - "Unexpected binary operator"); + assert(ISD::isBinaryOp(BO) && "Unexpected binary operator"); // Don't do this unless the old select is going away. We want to eliminate the // binary operator, not replace a binop with a select. @@ -1910,11 +1899,11 @@ SDValue DAGCombiner::foldBinOpIntoSelect(SDNode *BO) { // propagate non constant operands into select. I.e.: // and (select Cond, 0, -1), X --> select Cond, 0, X // or X, (select Cond, -1, 0) --> select Cond, -1, X - bool CanFoldNonConst = (BinOpcode == ISD::AND || BinOpcode == ISD::OR) && - (isNullConstantOrNullSplatConstant(CT) || - isAllOnesConstantOrAllOnesSplatConstant(CT)) && - (isNullConstantOrNullSplatConstant(CF) || - isAllOnesConstantOrAllOnesSplatConstant(CF)); + auto BinOpcode = BO->getOpcode(); + bool CanFoldNonConst = + (BinOpcode == ISD::AND || BinOpcode == ISD::OR) && + (isNullOrNullSplat(CT) || isAllOnesOrAllOnesSplat(CT)) && + (isNullOrNullSplat(CF) || isAllOnesOrAllOnesSplat(CF)); SDValue CBO = BO->getOperand(SelOpNo ^ 1); if (!CanFoldNonConst && @@ -2009,10 +1998,8 @@ static SDValue foldAddSubOfSignBit(SDNode *N, SelectionDAG &DAG) { return SDValue(); // The shift must be of a 'not' value. - // TODO: Use isBitwiseNot() if it works with vectors. SDValue Not = ShiftOp.getOperand(0); - if (!Not.hasOneUse() || Not.getOpcode() != ISD::XOR || - !isAllOnesConstantOrAllOnesSplatConstant(Not.getOperand(1))) + if (!Not.hasOneUse() || !isBitwiseNot(Not)) return SDValue(); // The shift must be moving the sign bit to the least-significant-bit. @@ -2085,7 +2072,7 @@ SDValue DAGCombiner::visitADD(SDNode *N) { // add (zext i1 X), -1 -> sext (not i1 X) // because most (?) targets generate better code for the zext form. if (N0.getOpcode() == ISD::SIGN_EXTEND && N0.hasOneUse() && - isOneConstantOrOneSplatConstant(N1)) { + isOneOrOneSplat(N1)) { SDValue X = N0.getOperand(0); if ((!LegalOperations || (TLI.isOperationLegal(ISD::XOR, X.getValueType()) && @@ -2110,17 +2097,15 @@ SDValue DAGCombiner::visitADD(SDNode *N) { return NewSel; // reassociate add - if (SDValue RADD = ReassociateOps(ISD::ADD, DL, N0, N1)) + if (SDValue RADD = ReassociateOps(ISD::ADD, DL, N0, N1, N->getFlags())) return RADD; // fold ((0-A) + B) -> B-A - if (N0.getOpcode() == ISD::SUB && - isNullConstantOrNullSplatConstant(N0.getOperand(0))) + if (N0.getOpcode() == ISD::SUB && isNullOrNullSplat(N0.getOperand(0))) return DAG.getNode(ISD::SUB, DL, VT, N1, N0.getOperand(1)); // fold (A + (0-B)) -> A-B - if (N1.getOpcode() == ISD::SUB && - isNullConstantOrNullSplatConstant(N1.getOperand(0))) + if (N1.getOpcode() == ISD::SUB && isNullOrNullSplat(N1.getOperand(0))) return DAG.getNode(ISD::SUB, DL, VT, N0, N1.getOperand(1)); // fold (A+(B-A)) -> B @@ -2178,7 +2163,7 @@ SDValue DAGCombiner::visitADD(SDNode *N) { return DAG.getNode(ISD::OR, DL, VT, N0, N1); // fold (add (xor a, -1), 1) -> (sub 0, a) - if (isBitwiseNot(N0) && isOneConstantOrOneSplatConstant(N1)) + if (isBitwiseNot(N0) && isOneOrOneSplat(N1)) return DAG.getNode(ISD::SUB, DL, VT, DAG.getConstant(0, DL, VT), N0.getOperand(0)); @@ -2191,6 +2176,49 @@ SDValue DAGCombiner::visitADD(SDNode *N) { return SDValue(); } +SDValue DAGCombiner::visitADDSAT(SDNode *N) { + unsigned Opcode = N->getOpcode(); + SDValue N0 = N->getOperand(0); + SDValue N1 = N->getOperand(1); + EVT VT = N0.getValueType(); + SDLoc DL(N); + + // fold vector ops + if (VT.isVector()) { + // TODO SimplifyVBinOp + + // fold (add_sat x, 0) -> x, vector edition + if (ISD::isBuildVectorAllZeros(N1.getNode())) + return N0; + if (ISD::isBuildVectorAllZeros(N0.getNode())) + return N1; + } + + // fold (add_sat x, undef) -> -1 + if (N0.isUndef() || N1.isUndef()) + return DAG.getAllOnesConstant(DL, VT); + + if (DAG.isConstantIntBuildVectorOrConstantInt(N0)) { + // canonicalize constant to RHS + if (!DAG.isConstantIntBuildVectorOrConstantInt(N1)) + return DAG.getNode(Opcode, DL, VT, N1, N0); + // fold (add_sat c1, c2) -> c3 + return DAG.FoldConstantArithmetic(Opcode, DL, VT, N0.getNode(), + N1.getNode()); + } + + // fold (add_sat x, 0) -> x + if (isNullConstant(N1)) + return N0; + + // If it cannot overflow, transform into an add. + if (Opcode == ISD::UADDSAT) + if (DAG.computeOverflowKind(N0, N1) == SelectionDAG::OFK_Never) + return DAG.getNode(ISD::ADD, DL, VT, N0, N1); + + return SDValue(); +} + static SDValue getAsCarry(const TargetLowering &TLI, SDValue V) { bool Masked = false; @@ -2235,7 +2263,7 @@ SDValue DAGCombiner::visitADDLike(SDValue N0, SDValue N1, SDNode *LocReference) // fold (add x, shl(0 - y, n)) -> sub(x, shl(y, n)) if (N1.getOpcode() == ISD::SHL && N1.getOperand(0).getOpcode() == ISD::SUB && - isNullConstantOrNullSplatConstant(N1.getOperand(0).getOperand(0))) + isNullOrNullSplat(N1.getOperand(0).getOperand(0))) return DAG.getNode(ISD::SUB, DL, VT, N0, DAG.getNode(ISD::SHL, DL, VT, N1.getOperand(0).getOperand(1), @@ -2248,8 +2276,7 @@ SDValue DAGCombiner::visitADDLike(SDValue N0, SDValue N1, SDNode *LocReference) // (add z, (and (sbbl x, x), 1)) -> (sub z, (sbbl x, x)) // and similar xforms where the inner op is either ~0 or 0. - if (NumSignBits == DestBits && - isOneConstantOrOneSplatConstant(N1->getOperand(1))) + if (NumSignBits == DestBits && isOneOrOneSplat(N1->getOperand(1))) return DAG.getNode(ISD::SUB, DL, VT, N0, AndOp0); } @@ -2380,7 +2407,7 @@ SDValue DAGCombiner::visitUADDO(SDNode *N) { DAG.getConstant(0, DL, CarryVT)); // fold (uaddo (xor a, -1), 1) -> (usub 0, a) and flip carry. - if (isBitwiseNot(N0) && isOneConstantOrOneSplatConstant(N1)) { + if (isBitwiseNot(N0) && isOneOrOneSplat(N1)) { SDValue Sub = DAG.getNode(ISD::USUBO, DL, N->getVTList(), DAG.getConstant(0, DL, VT), N0.getOperand(0)); @@ -2539,8 +2566,7 @@ SDValue DAGCombiner::visitADDCARRYLike(SDValue N0, SDValue N1, SDValue CarryIn, // Since it may not be valid to emit a fold to zero for vector initializers // check if we can before folding. static SDValue tryFoldToZero(const SDLoc &DL, const TargetLowering &TLI, EVT VT, - SelectionDAG &DAG, bool LegalOperations, - bool LegalTypes) { + SelectionDAG &DAG, bool LegalOperations) { if (!VT.isVector()) return DAG.getConstant(0, DL, VT); if (!LegalOperations || TLI.isOperationLegal(ISD::BUILD_VECTOR, VT)) @@ -2567,7 +2593,7 @@ SDValue DAGCombiner::visitSUB(SDNode *N) { // fold (sub x, x) -> 0 // FIXME: Refactor this and xor and other similar operations together. if (N0 == N1) - return tryFoldToZero(DL, TLI, VT, DAG, LegalOperations, LegalTypes); + return tryFoldToZero(DL, TLI, VT, DAG, LegalOperations); if (DAG.isConstantIntBuildVectorOrConstantInt(N0) && DAG.isConstantIntBuildVectorOrConstantInt(N1)) { // fold (sub c1, c2) -> c1-c2 @@ -2586,7 +2612,7 @@ SDValue DAGCombiner::visitSUB(SDNode *N) { DAG.getConstant(-N1C->getAPIntValue(), DL, VT)); } - if (isNullConstantOrNullSplatConstant(N0)) { + if (isNullOrNullSplat(N0)) { unsigned BitWidth = VT.getScalarSizeInBits(); // Right-shifting everything out but the sign bit followed by negation is // the same as flipping arithmetic/logical shift type without the negation: @@ -2617,12 +2643,11 @@ SDValue DAGCombiner::visitSUB(SDNode *N) { } // Canonicalize (sub -1, x) -> ~x, i.e. (xor x, -1) - if (isAllOnesConstantOrAllOnesSplatConstant(N0)) + if (isAllOnesOrAllOnesSplat(N0)) return DAG.getNode(ISD::XOR, DL, VT, N1, N0); // fold (A - (0-B)) -> A+B - if (N1.getOpcode() == ISD::SUB && - isNullConstantOrNullSplatConstant(N1.getOperand(0))) + if (N1.getOpcode() == ISD::SUB && isNullOrNullSplat(N1.getOperand(0))) return DAG.getNode(ISD::ADD, DL, VT, N0, N1.getOperand(1)); // fold A-(A-B) -> B @@ -2676,14 +2701,14 @@ SDValue DAGCombiner::visitSUB(SDNode *N) { // fold (X - (-Y * Z)) -> (X + (Y * Z)) if (N1.getOpcode() == ISD::MUL && N1.hasOneUse()) { if (N1.getOperand(0).getOpcode() == ISD::SUB && - isNullConstantOrNullSplatConstant(N1.getOperand(0).getOperand(0))) { + isNullOrNullSplat(N1.getOperand(0).getOperand(0))) { SDValue Mul = DAG.getNode(ISD::MUL, DL, VT, N1.getOperand(0).getOperand(1), N1.getOperand(1)); return DAG.getNode(ISD::ADD, DL, VT, N0, Mul); } if (N1.getOperand(1).getOpcode() == ISD::SUB && - isNullConstantOrNullSplatConstant(N1.getOperand(1).getOperand(0))) { + isNullOrNullSplat(N1.getOperand(1).getOperand(0))) { SDValue Mul = DAG.getNode(ISD::MUL, DL, VT, N1.getOperand(0), N1.getOperand(1).getOperand(1)); @@ -2756,6 +2781,43 @@ SDValue DAGCombiner::visitSUB(SDNode *N) { return SDValue(); } +SDValue DAGCombiner::visitSUBSAT(SDNode *N) { + SDValue N0 = N->getOperand(0); + SDValue N1 = N->getOperand(1); + EVT VT = N0.getValueType(); + SDLoc DL(N); + + // fold vector ops + if (VT.isVector()) { + // TODO SimplifyVBinOp + + // fold (sub_sat x, 0) -> x, vector edition + if (ISD::isBuildVectorAllZeros(N1.getNode())) + return N0; + } + + // fold (sub_sat x, undef) -> 0 + if (N0.isUndef() || N1.isUndef()) + return DAG.getConstant(0, DL, VT); + + // fold (sub_sat x, x) -> 0 + if (N0 == N1) + return DAG.getConstant(0, DL, VT); + + if (DAG.isConstantIntBuildVectorOrConstantInt(N0) && + DAG.isConstantIntBuildVectorOrConstantInt(N1)) { + // fold (sub_sat c1, c2) -> c3 + return DAG.FoldConstantArithmetic(N->getOpcode(), DL, VT, N0.getNode(), + N1.getNode()); + } + + // fold (sub_sat x, 0) -> x + if (isNullConstant(N1)) + return N0; + + return SDValue(); +} + SDValue DAGCombiner::visitSUBC(SDNode *N) { SDValue N0 = N->getOperand(0); SDValue N1 = N->getOperand(1); @@ -2931,6 +2993,39 @@ SDValue DAGCombiner::visitMUL(SDNode *N) { getShiftAmountTy(N0.getValueType())))); } + // Try to transform multiply-by-(power-of-2 +/- 1) into shift and add/sub. + // mul x, (2^N + 1) --> add (shl x, N), x + // mul x, (2^N - 1) --> sub (shl x, N), x + // Examples: x * 33 --> (x << 5) + x + // x * 15 --> (x << 4) - x + // x * -33 --> -((x << 5) + x) + // x * -15 --> -((x << 4) - x) ; this reduces --> x - (x << 4) + if (N1IsConst && TLI.decomposeMulByConstant(VT, N1)) { + // TODO: We could handle more general decomposition of any constant by + // having the target set a limit on number of ops and making a + // callback to determine that sequence (similar to sqrt expansion). + unsigned MathOp = ISD::DELETED_NODE; + APInt MulC = ConstValue1.abs(); + if ((MulC - 1).isPowerOf2()) + MathOp = ISD::ADD; + else if ((MulC + 1).isPowerOf2()) + MathOp = ISD::SUB; + + if (MathOp != ISD::DELETED_NODE) { + unsigned ShAmt = MathOp == ISD::ADD ? (MulC - 1).logBase2() + : (MulC + 1).logBase2(); + assert(ShAmt > 0 && ShAmt < VT.getScalarSizeInBits() && + "Not expecting multiply-by-constant that could have simplified"); + SDLoc DL(N); + SDValue Shl = DAG.getNode(ISD::SHL, DL, VT, N0, + DAG.getConstant(ShAmt, DL, VT)); + SDValue R = DAG.getNode(MathOp, DL, VT, Shl, N0); + if (ConstValue1.isNegative()) + R = DAG.getNode(ISD::SUB, DL, VT, DAG.getConstant(0, DL, VT), R); + return R; + } + } + // (mul (shl X, c1), c2) -> (mul X, c2 << c1) if (N0.getOpcode() == ISD::SHL && isConstantOrConstantVector(N1, /* NoOpaques */ true) && @@ -2974,7 +3069,7 @@ SDValue DAGCombiner::visitMUL(SDNode *N) { N0.getOperand(1), N1)); // reassociate mul - if (SDValue RMUL = ReassociateOps(ISD::MUL, SDLoc(N), N0, N1)) + if (SDValue RMUL = ReassociateOps(ISD::MUL, SDLoc(N), N0, N1, N->getFlags())) return RMUL; return SDValue(); @@ -3076,7 +3171,16 @@ static SDValue simplifyDivRem(SDNode *N, SelectionDAG &DAG) { EVT VT = N->getValueType(0); SDLoc DL(N); - if (DAG.isUndef(N->getOpcode(), {N0, N1})) + unsigned Opc = N->getOpcode(); + bool IsDiv = (ISD::SDIV == Opc) || (ISD::UDIV == Opc); + ConstantSDNode *N1C = isConstOrConstSplat(N1); + + // X / undef -> undef + // X % undef -> undef + // X / 0 -> undef + // X % 0 -> undef + // NOTE: This includes vectors where any divisor element is zero/undef. + if (DAG.isUndef(Opc, {N0, N1})) return DAG.getUNDEF(VT); // undef / X -> 0 @@ -3084,6 +3188,26 @@ static SDValue simplifyDivRem(SDNode *N, SelectionDAG &DAG) { if (N0.isUndef()) return DAG.getConstant(0, DL, VT); + // 0 / X -> 0 + // 0 % X -> 0 + ConstantSDNode *N0C = isConstOrConstSplat(N0); + if (N0C && N0C->isNullValue()) + return N0; + + // X / X -> 1 + // X % X -> 0 + if (N0 == N1) + return DAG.getConstant(IsDiv ? 1 : 0, DL, VT); + + // X / 1 -> X + // X % 1 -> 0 + // If this is a boolean op (single-bit element type), we can't have + // division-by-zero or remainder-by-zero, so assume the divisor is 1. + // TODO: Similarly, if we're zero-extending a boolean divisor, then assume + // it's a 1. + if ((N1C && N1C->isOne()) || (VT.getScalarType() == MVT::i1)) + return IsDiv ? N0 : DAG.getConstant(0, DL, VT); + return SDValue(); } @@ -3105,9 +3229,6 @@ SDValue DAGCombiner::visitSDIV(SDNode *N) { ConstantSDNode *N1C = isConstOrConstSplat(N1); if (N0C && N1C && !N0C->isOpaque() && !N1C->isOpaque()) return DAG.FoldConstantArithmetic(ISD::SDIV, DL, VT, N0C, N1C); - // fold (sdiv X, 1) -> X - if (N1C && N1C->isOne()) - return N0; // fold (sdiv X, -1) -> 0-X if (N1C && N1C->isAllOnesValue()) return DAG.getNode(ISD::SUB, DL, VT, DAG.getConstant(0, DL, VT), N0); @@ -3128,8 +3249,19 @@ SDValue DAGCombiner::visitSDIV(SDNode *N) { if (DAG.SignBitIsZero(N1) && DAG.SignBitIsZero(N0)) return DAG.getNode(ISD::UDIV, DL, N1.getValueType(), N0, N1); - if (SDValue V = visitSDIVLike(N0, N1, N)) + if (SDValue V = visitSDIVLike(N0, N1, N)) { + // If the corresponding remainder node exists, update its users with + // (Dividend - (Quotient * Divisor). + if (SDNode *RemNode = DAG.getNodeIfExists(ISD::SREM, N->getVTList(), + { N0, N1 })) { + SDValue Mul = DAG.getNode(ISD::MUL, DL, VT, V, N1); + SDValue Sub = DAG.getNode(ISD::SUB, DL, VT, N0, Mul); + AddToWorklist(Mul.getNode()); + AddToWorklist(Sub.getNode()); + CombineTo(RemNode, Sub); + } return V; + } // sdiv, srem -> sdivrem // If the divisor is constant, then return DIVREM only if isIntDivCheap() is @@ -3148,8 +3280,6 @@ SDValue DAGCombiner::visitSDIVLike(SDValue N0, SDValue N1, SDNode *N) { EVT CCVT = getSetCCResultType(VT); unsigned BitWidth = VT.getScalarSizeInBits(); - ConstantSDNode *N1C = isConstOrConstSplat(N1); - // Helper for determining whether a value is a power-2 constant scalar or a // vector of such elements. auto IsPowerOfTwo = [](ConstantSDNode *C) { @@ -3166,8 +3296,7 @@ SDValue DAGCombiner::visitSDIVLike(SDValue N0, SDValue N1, SDNode *N) { // FIXME: We check for the exact bit here because the generic lowering gives // better results in that case. The target-specific lowering should learn how // to handle exact sdivs efficiently. - if (!N->getFlags().hasExact() && - ISD::matchUnaryPredicate(N1C ? SDValue(N1C, 0) : N1, IsPowerOfTwo)) { + if (!N->getFlags().hasExact() && ISD::matchUnaryPredicate(N1, IsPowerOfTwo)) { // Target-specific implementation of sdiv x, pow2. if (SDValue Res = BuildSDIVPow2(N)) return Res; @@ -3218,7 +3347,8 @@ SDValue DAGCombiner::visitSDIVLike(SDValue N0, SDValue N1, SDNode *N) { // alternate sequence. Targets may check function attributes for size/speed // trade-offs. AttributeList Attr = DAG.getMachineFunction().getFunction().getAttributes(); - if (N1C && !TLI.isIntDivCheap(N->getValueType(0), Attr)) + if (isConstantOrConstantVector(N1) && + !TLI.isIntDivCheap(N->getValueType(0), Attr)) if (SDValue Op = BuildSDIV(N)) return Op; @@ -3245,9 +3375,6 @@ SDValue DAGCombiner::visitUDIV(SDNode *N) { if (SDValue Folded = DAG.FoldConstantArithmetic(ISD::UDIV, DL, VT, N0C, N1C)) return Folded; - // fold (udiv X, 1) -> X - if (N1C && N1C->isOne()) - return N0; // fold (udiv X, -1) -> select(X == -1, 1, 0) if (N1C && N1C->getAPIntValue().isAllOnesValue()) return DAG.getSelect(DL, VT, DAG.getSetCC(DL, CCVT, N0, N1, ISD::SETEQ), @@ -3260,8 +3387,19 @@ SDValue DAGCombiner::visitUDIV(SDNode *N) { if (SDValue NewSel = foldBinOpIntoSelect(N)) return NewSel; - if (SDValue V = visitUDIVLike(N0, N1, N)) + if (SDValue V = visitUDIVLike(N0, N1, N)) { + // If the corresponding remainder node exists, update its users with + // (Dividend - (Quotient * Divisor). + if (SDNode *RemNode = DAG.getNodeIfExists(ISD::UREM, N->getVTList(), + { N0, N1 })) { + SDValue Mul = DAG.getNode(ISD::MUL, DL, VT, V, N1); + SDValue Sub = DAG.getNode(ISD::SUB, DL, VT, N0, Mul); + AddToWorklist(Mul.getNode()); + AddToWorklist(Sub.getNode()); + CombineTo(RemNode, Sub); + } return V; + } // sdiv, srem -> sdivrem // If the divisor is constant, then return DIVREM only if isIntDivCheap() is @@ -3278,8 +3416,6 @@ SDValue DAGCombiner::visitUDIVLike(SDValue N0, SDValue N1, SDNode *N) { SDLoc DL(N); EVT VT = N->getValueType(0); - ConstantSDNode *N1C = isConstOrConstSplat(N1); - // fold (udiv x, (1 << c)) -> x >>u c if (isConstantOrConstantVector(N1, /*NoOpaques*/ true) && DAG.isKnownToBeAPowerOfTwo(N1)) { @@ -3311,7 +3447,8 @@ SDValue DAGCombiner::visitUDIVLike(SDValue N0, SDValue N1, SDNode *N) { // fold (udiv x, c) -> alternate AttributeList Attr = DAG.getMachineFunction().getFunction().getAttributes(); - if (N1C && !TLI.isIntDivCheap(N->getValueType(0), Attr)) + if (isConstantOrConstantVector(N1) && + !TLI.isIntDivCheap(N->getValueType(0), Attr)) if (SDValue Op = BuildUDIV(N)) return Op; @@ -3380,8 +3517,12 @@ SDValue DAGCombiner::visitREM(SDNode *N) { if (DAG.isKnownNeverZero(N1) && !TLI.isIntDivCheap(VT, Attr)) { SDValue OptimizedDiv = isSigned ? visitSDIVLike(N0, N1, N) : visitUDIVLike(N0, N1, N); - if (OptimizedDiv.getNode() && OptimizedDiv.getOpcode() != ISD::UDIVREM && - OptimizedDiv.getOpcode() != ISD::SDIVREM) { + if (OptimizedDiv.getNode()) { + // If the equivalent Div node also exists, update its users. + unsigned DivOpcode = isSigned ? ISD::SDIV : ISD::UDIV; + if (SDNode *DivNode = DAG.getNodeIfExists(DivOpcode, N->getVTList(), + { N0, N1 })) + CombineTo(DivNode, OptimizedDiv); SDValue Mul = DAG.getNode(ISD::MUL, DL, VT, OptimizedDiv, N1); SDValue Sub = DAG.getNode(ISD::SUB, DL, VT, N0, Mul); AddToWorklist(OptimizedDiv.getNode()); @@ -3468,6 +3609,19 @@ SDValue DAGCombiner::visitMULHU(SDNode *N) { if (N0.isUndef() || N1.isUndef()) return DAG.getConstant(0, DL, VT); + // fold (mulhu x, (1 << c)) -> x >> (bitwidth - c) + if (isConstantOrConstantVector(N1, /*NoOpaques*/ true) && + DAG.isKnownToBeAPowerOfTwo(N1) && hasOperation(ISD::SRL, VT)) { + SDLoc DL(N); + unsigned NumEltBits = VT.getScalarSizeInBits(); + SDValue LogBase2 = BuildLogBase2(N1, DL); + SDValue SRLAmt = DAG.getNode( + ISD::SUB, DL, VT, DAG.getConstant(NumEltBits, DL, VT), LogBase2); + EVT ShiftVT = getShiftAmountTy(N0.getValueType()); + SDValue Trunc = DAG.getZExtOrTrunc(SRLAmt, DL, ShiftVT); + return DAG.getNode(ISD::SRL, DL, VT, N0, Trunc); + } + // If the type twice as wide is legal, transform the mulhu to a wider multiply // plus a shift. if (VT.isSimple() && !VT.isVector()) { @@ -3495,18 +3649,16 @@ SDValue DAGCombiner::SimplifyNodeWithTwoResults(SDNode *N, unsigned LoOp, unsigned HiOp) { // If the high half is not needed, just compute the low half. bool HiExists = N->hasAnyUseOfValue(1); - if (!HiExists && - (!LegalOperations || - TLI.isOperationLegalOrCustom(LoOp, N->getValueType(0)))) { + if (!HiExists && (!LegalOperations || + TLI.isOperationLegalOrCustom(LoOp, N->getValueType(0)))) { SDValue Res = DAG.getNode(LoOp, SDLoc(N), N->getValueType(0), N->ops()); return CombineTo(N, Res, Res); } // If the low half is not needed, just compute the high half. bool LoExists = N->hasAnyUseOfValue(0); - if (!LoExists && - (!LegalOperations || - TLI.isOperationLegal(HiOp, N->getValueType(1)))) { + if (!LoExists && (!LegalOperations || + TLI.isOperationLegalOrCustom(HiOp, N->getValueType(1)))) { SDValue Res = DAG.getNode(HiOp, SDLoc(N), N->getValueType(1), N->ops()); return CombineTo(N, Res, Res); } @@ -3522,7 +3674,7 @@ SDValue DAGCombiner::SimplifyNodeWithTwoResults(SDNode *N, unsigned LoOp, SDValue LoOpt = combine(Lo.getNode()); if (LoOpt.getNode() && LoOpt.getNode() != Lo.getNode() && (!LegalOperations || - TLI.isOperationLegal(LoOpt.getOpcode(), LoOpt.getValueType()))) + TLI.isOperationLegalOrCustom(LoOpt.getOpcode(), LoOpt.getValueType()))) return CombineTo(N, LoOpt, LoOpt); } @@ -3532,7 +3684,7 @@ SDValue DAGCombiner::SimplifyNodeWithTwoResults(SDNode *N, unsigned LoOp, SDValue HiOpt = combine(Hi.getNode()); if (HiOpt.getNode() && HiOpt != Hi && (!LegalOperations || - TLI.isOperationLegal(HiOpt.getOpcode(), HiOpt.getValueType()))) + TLI.isOperationLegalOrCustom(HiOpt.getOpcode(), HiOpt.getValueType()))) return CombineTo(N, HiOpt, HiOpt); } @@ -3664,59 +3816,94 @@ SDValue DAGCombiner::visitIMINMAX(SDNode *N) { return SDValue(); } -/// If this is a binary operator with two operands of the same opcode, try to -/// simplify it. -SDValue DAGCombiner::SimplifyBinOpWithSameOpcodeHands(SDNode *N) { +/// If this is a bitwise logic instruction and both operands have the same +/// opcode, try to sink the other opcode after the logic instruction. +SDValue DAGCombiner::hoistLogicOpWithSameOpcodeHands(SDNode *N) { SDValue N0 = N->getOperand(0), N1 = N->getOperand(1); EVT VT = N0.getValueType(); - assert(N0.getOpcode() == N1.getOpcode() && "Bad input!"); + unsigned LogicOpcode = N->getOpcode(); + unsigned HandOpcode = N0.getOpcode(); + assert((LogicOpcode == ISD::AND || LogicOpcode == ISD::OR || + LogicOpcode == ISD::XOR) && "Expected logic opcode"); + assert(HandOpcode == N1.getOpcode() && "Bad input!"); // Bail early if none of these transforms apply. - if (N0.getNumOperands() == 0) return SDValue(); - - // For each of OP in AND/OR/XOR: - // fold (OP (zext x), (zext y)) -> (zext (OP x, y)) - // fold (OP (sext x), (sext y)) -> (sext (OP x, y)) - // fold (OP (aext x), (aext y)) -> (aext (OP x, y)) - // fold (OP (bswap x), (bswap y)) -> (bswap (OP x, y)) - // fold (OP (trunc x), (trunc y)) -> (trunc (OP x, y)) (if trunc isn't free) - // - // do not sink logical op inside of a vector extend, since it may combine - // into a vsetcc. - EVT Op0VT = N0.getOperand(0).getValueType(); - if ((N0.getOpcode() == ISD::ZERO_EXTEND || - N0.getOpcode() == ISD::SIGN_EXTEND || - N0.getOpcode() == ISD::BSWAP || - // Avoid infinite looping with PromoteIntBinOp. - (N0.getOpcode() == ISD::ANY_EXTEND && - (!LegalTypes || TLI.isTypeDesirableForOp(N->getOpcode(), Op0VT))) || - (N0.getOpcode() == ISD::TRUNCATE && - (!TLI.isZExtFree(VT, Op0VT) || - !TLI.isTruncateFree(Op0VT, VT)) && - TLI.isTypeLegal(Op0VT))) && - !VT.isVector() && - Op0VT == N1.getOperand(0).getValueType() && - (!LegalOperations || TLI.isOperationLegal(N->getOpcode(), Op0VT))) { - SDValue ORNode = DAG.getNode(N->getOpcode(), SDLoc(N0), - N0.getOperand(0).getValueType(), - N0.getOperand(0), N1.getOperand(0)); - AddToWorklist(ORNode.getNode()); - return DAG.getNode(N0.getOpcode(), SDLoc(N), VT, ORNode); - } - - // For each of OP in SHL/SRL/SRA/AND... - // fold (and (OP x, z), (OP y, z)) -> (OP (and x, y), z) - // fold (or (OP x, z), (OP y, z)) -> (OP (or x, y), z) - // fold (xor (OP x, z), (OP y, z)) -> (OP (xor x, y), z) - if ((N0.getOpcode() == ISD::SHL || N0.getOpcode() == ISD::SRL || - N0.getOpcode() == ISD::SRA || N0.getOpcode() == ISD::AND) && + if (N0.getNumOperands() == 0) + return SDValue(); + + // FIXME: We should check number of uses of the operands to not increase + // the instruction count for all transforms. + + // Handle size-changing casts. + SDValue X = N0.getOperand(0); + SDValue Y = N1.getOperand(0); + EVT XVT = X.getValueType(); + SDLoc DL(N); + if (HandOpcode == ISD::ANY_EXTEND || HandOpcode == ISD::ZERO_EXTEND || + HandOpcode == ISD::SIGN_EXTEND) { + // If both operands have other uses, this transform would create extra + // instructions without eliminating anything. + if (!N0.hasOneUse() && !N1.hasOneUse()) + return SDValue(); + // We need matching integer source types. + if (XVT != Y.getValueType()) + return SDValue(); + // Don't create an illegal op during or after legalization. Don't ever + // create an unsupported vector op. + if ((VT.isVector() || LegalOperations) && + !TLI.isOperationLegalOrCustom(LogicOpcode, XVT)) + return SDValue(); + // Avoid infinite looping with PromoteIntBinOp. + // TODO: Should we apply desirable/legal constraints to all opcodes? + if (HandOpcode == ISD::ANY_EXTEND && LegalTypes && + !TLI.isTypeDesirableForOp(LogicOpcode, XVT)) + return SDValue(); + // logic_op (hand_op X), (hand_op Y) --> hand_op (logic_op X, Y) + SDValue Logic = DAG.getNode(LogicOpcode, DL, XVT, X, Y); + return DAG.getNode(HandOpcode, DL, VT, Logic); + } + + // logic_op (truncate x), (truncate y) --> truncate (logic_op x, y) + if (HandOpcode == ISD::TRUNCATE) { + // If both operands have other uses, this transform would create extra + // instructions without eliminating anything. + if (!N0.hasOneUse() && !N1.hasOneUse()) + return SDValue(); + // We need matching source types. + if (XVT != Y.getValueType()) + return SDValue(); + // Don't create an illegal op during or after legalization. + if (LegalOperations && !TLI.isOperationLegal(LogicOpcode, XVT)) + return SDValue(); + // Be extra careful sinking truncate. If it's free, there's no benefit in + // widening a binop. Also, don't create a logic op on an illegal type. + if (TLI.isZExtFree(VT, XVT) && TLI.isTruncateFree(XVT, VT)) + return SDValue(); + if (!TLI.isTypeLegal(XVT)) + return SDValue(); + SDValue Logic = DAG.getNode(LogicOpcode, DL, XVT, X, Y); + return DAG.getNode(HandOpcode, DL, VT, Logic); + } + + // For binops SHL/SRL/SRA/AND: + // logic_op (OP x, z), (OP y, z) --> OP (logic_op x, y), z + if ((HandOpcode == ISD::SHL || HandOpcode == ISD::SRL || + HandOpcode == ISD::SRA || HandOpcode == ISD::AND) && N0.getOperand(1) == N1.getOperand(1)) { - SDValue ORNode = DAG.getNode(N->getOpcode(), SDLoc(N0), - N0.getOperand(0).getValueType(), - N0.getOperand(0), N1.getOperand(0)); - AddToWorklist(ORNode.getNode()); - return DAG.getNode(N0.getOpcode(), SDLoc(N), VT, - ORNode, N0.getOperand(1)); + // If either operand has other uses, this transform is not an improvement. + if (!N0.hasOneUse() || !N1.hasOneUse()) + return SDValue(); + SDValue Logic = DAG.getNode(LogicOpcode, DL, XVT, X, Y); + return DAG.getNode(HandOpcode, DL, VT, Logic, N0.getOperand(1)); + } + + // Unary ops: logic_op (bswap x), (bswap y) --> bswap (logic_op x, y) + if (HandOpcode == ISD::BSWAP) { + // If either operand has other uses, this transform is not an improvement. + if (!N0.hasOneUse() || !N1.hasOneUse()) + return SDValue(); + SDValue Logic = DAG.getNode(LogicOpcode, DL, XVT, X, Y); + return DAG.getNode(HandOpcode, DL, VT, Logic); } // Simplify xor/and/or (bitcast(A), bitcast(B)) -> bitcast(op (A,B)) @@ -3726,21 +3913,12 @@ SDValue DAGCombiner::SimplifyBinOpWithSameOpcodeHands(SDNode *N) { // we don't want to undo this promotion. // We also handle SCALAR_TO_VECTOR because xor/or/and operations are cheaper // on scalars. - if ((N0.getOpcode() == ISD::BITCAST || - N0.getOpcode() == ISD::SCALAR_TO_VECTOR) && + if ((HandOpcode == ISD::BITCAST || HandOpcode == ISD::SCALAR_TO_VECTOR) && Level <= AfterLegalizeTypes) { - SDValue In0 = N0.getOperand(0); - SDValue In1 = N1.getOperand(0); - EVT In0Ty = In0.getValueType(); - EVT In1Ty = In1.getValueType(); - SDLoc DL(N); - // If both incoming values are integers, and the original types are the - // same. - if (In0Ty.isInteger() && In1Ty.isInteger() && In0Ty == In1Ty) { - SDValue Op = DAG.getNode(N->getOpcode(), DL, In0Ty, In0, In1); - SDValue BC = DAG.getNode(N0.getOpcode(), DL, VT, Op); - AddToWorklist(Op.getNode()); - return BC; + // Input types must be integer and the same. + if (XVT.isInteger() && XVT == Y.getValueType()) { + SDValue Logic = DAG.getNode(LogicOpcode, DL, XVT, X, Y); + return DAG.getNode(HandOpcode, DL, VT, Logic); } } @@ -3756,61 +3934,44 @@ SDValue DAGCombiner::SimplifyBinOpWithSameOpcodeHands(SDNode *N) { // If both shuffles use the same mask, and both shuffles have the same first // or second operand, then it might still be profitable to move the shuffle // after the xor/and/or operation. - if (N0.getOpcode() == ISD::VECTOR_SHUFFLE && Level < AfterLegalizeDAG) { - ShuffleVectorSDNode *SVN0 = cast<ShuffleVectorSDNode>(N0); - ShuffleVectorSDNode *SVN1 = cast<ShuffleVectorSDNode>(N1); - - assert(N0.getOperand(0).getValueType() == N1.getOperand(0).getValueType() && + if (HandOpcode == ISD::VECTOR_SHUFFLE && Level < AfterLegalizeDAG) { + auto *SVN0 = cast<ShuffleVectorSDNode>(N0); + auto *SVN1 = cast<ShuffleVectorSDNode>(N1); + assert(X.getValueType() == Y.getValueType() && "Inputs to shuffles are not the same type"); // Check that both shuffles use the same mask. The masks are known to be of // the same length because the result vector type is the same. // Check also that shuffles have only one use to avoid introducing extra // instructions. - if (SVN0->hasOneUse() && SVN1->hasOneUse() && - SVN0->getMask().equals(SVN1->getMask())) { - SDValue ShOp = N0->getOperand(1); - - // Don't try to fold this node if it requires introducing a - // build vector of all zeros that might be illegal at this stage. - if (N->getOpcode() == ISD::XOR && !ShOp.isUndef()) { - if (!LegalTypes) - ShOp = DAG.getConstant(0, SDLoc(N), VT); - else - ShOp = SDValue(); - } + if (!SVN0->hasOneUse() || !SVN1->hasOneUse() || + !SVN0->getMask().equals(SVN1->getMask())) + return SDValue(); - // (AND (shuf (A, C), shuf (B, C))) -> shuf (AND (A, B), C) - // (OR (shuf (A, C), shuf (B, C))) -> shuf (OR (A, B), C) - // (XOR (shuf (A, C), shuf (B, C))) -> shuf (XOR (A, B), V_0) - if (N0.getOperand(1) == N1.getOperand(1) && ShOp.getNode()) { - SDValue NewNode = DAG.getNode(N->getOpcode(), SDLoc(N), VT, - N0->getOperand(0), N1->getOperand(0)); - AddToWorklist(NewNode.getNode()); - return DAG.getVectorShuffle(VT, SDLoc(N), NewNode, ShOp, - SVN0->getMask()); - } + // Don't try to fold this node if it requires introducing a + // build vector of all zeros that might be illegal at this stage. + SDValue ShOp = N0.getOperand(1); + if (LogicOpcode == ISD::XOR && !ShOp.isUndef()) + ShOp = tryFoldToZero(DL, TLI, VT, DAG, LegalOperations); - // Don't try to fold this node if it requires introducing a - // build vector of all zeros that might be illegal at this stage. - ShOp = N0->getOperand(0); - if (N->getOpcode() == ISD::XOR && !ShOp.isUndef()) { - if (!LegalTypes) - ShOp = DAG.getConstant(0, SDLoc(N), VT); - else - ShOp = SDValue(); - } + // (logic_op (shuf (A, C), shuf (B, C))) --> shuf (logic_op (A, B), C) + if (N0.getOperand(1) == N1.getOperand(1) && ShOp.getNode()) { + SDValue Logic = DAG.getNode(LogicOpcode, DL, VT, + N0.getOperand(0), N1.getOperand(0)); + return DAG.getVectorShuffle(VT, DL, Logic, ShOp, SVN0->getMask()); + } - // (AND (shuf (C, A), shuf (C, B))) -> shuf (C, AND (A, B)) - // (OR (shuf (C, A), shuf (C, B))) -> shuf (C, OR (A, B)) - // (XOR (shuf (C, A), shuf (C, B))) -> shuf (V_0, XOR (A, B)) - if (N0->getOperand(0) == N1->getOperand(0) && ShOp.getNode()) { - SDValue NewNode = DAG.getNode(N->getOpcode(), SDLoc(N), VT, - N0->getOperand(1), N1->getOperand(1)); - AddToWorklist(NewNode.getNode()); - return DAG.getVectorShuffle(VT, SDLoc(N), ShOp, NewNode, - SVN0->getMask()); - } + // Don't try to fold this node if it requires introducing a + // build vector of all zeros that might be illegal at this stage. + ShOp = N0.getOperand(0); + if (LogicOpcode == ISD::XOR && !ShOp.isUndef()) + ShOp = tryFoldToZero(DL, TLI, VT, DAG, LegalOperations); + + // (logic_op (shuf (C, A), shuf (C, B))) --> shuf (C, logic_op (A, B)) + if (N0.getOperand(0) == N1.getOperand(0) && ShOp.getNode()) { + SDValue Logic = DAG.getNode(LogicOpcode, DL, VT, N0.getOperand(1), + N1.getOperand(1)); + return DAG.getVectorShuffle(VT, DL, ShOp, Logic, SVN0->getMask()); } } @@ -3846,8 +4007,8 @@ SDValue DAGCombiner::foldLogicOfSetCCs(bool IsAnd, SDValue N0, SDValue N1, ISD::CondCode CC1 = cast<CondCodeSDNode>(N1CC)->get(); bool IsInteger = OpVT.isInteger(); if (LR == RR && CC0 == CC1 && IsInteger) { - bool IsZero = isNullConstantOrNullSplatConstant(LR); - bool IsNeg1 = isAllOnesConstantOrAllOnesSplatConstant(LR); + bool IsZero = isNullOrNullSplat(LR); + bool IsNeg1 = isAllOnesOrAllOnesSplat(LR); // All bits clear? bool AndEqZero = IsAnd && CC1 == ISD::SETEQ && IsZero; @@ -4149,7 +4310,7 @@ bool DAGCombiner::isLegalNarrowLdSt(LSBaseSDNode *LDST, } bool DAGCombiner::SearchForAndLoads(SDNode *N, - SmallPtrSetImpl<LoadSDNode*> &Loads, + SmallVectorImpl<LoadSDNode*> &Loads, SmallPtrSetImpl<SDNode*> &NodesWithConsts, ConstantSDNode *Mask, SDNode *&NodeToMask) { @@ -4186,7 +4347,7 @@ bool DAGCombiner::SearchForAndLoads(SDNode *N, // Use LE to convert equal sized loads to zext. if (ExtVT.bitsLE(Load->getMemoryVT())) - Loads.insert(Load); + Loads.push_back(Load); continue; } @@ -4251,7 +4412,7 @@ bool DAGCombiner::BackwardsPropagateMask(SDNode *N, SelectionDAG &DAG) { if (isa<LoadSDNode>(N->getOperand(0))) return false; - SmallPtrSet<LoadSDNode*, 8> Loads; + SmallVector<LoadSDNode*, 8> Loads; SmallPtrSet<SDNode*, 2> NodesWithConsts; SDNode *FixupNode = nullptr; if (SearchForAndLoads(N, Loads, NodesWithConsts, Mask, FixupNode)) { @@ -4399,7 +4560,7 @@ SDValue DAGCombiner::visitAND(SDNode *N) { return DAG.FoldConstantArithmetic(ISD::AND, SDLoc(N), VT, N0C, N1C); // canonicalize constant to RHS if (DAG.isConstantIntBuildVectorOrConstantInt(N0) && - !DAG.isConstantIntBuildVectorOrConstantInt(N1)) + !DAG.isConstantIntBuildVectorOrConstantInt(N1)) return DAG.getNode(ISD::AND, SDLoc(N), VT, N1, N0); // fold (and x, -1) -> x if (isAllOnesConstant(N1)) @@ -4414,7 +4575,7 @@ SDValue DAGCombiner::visitAND(SDNode *N) { return NewSel; // reassociate and - if (SDValue RAND = ReassociateOps(ISD::AND, SDLoc(N), N0, N1)) + if (SDValue RAND = ReassociateOps(ISD::AND, SDLoc(N), N0, N1, N->getFlags())) return RAND; // Try to convert a constant mask AND into a shuffle clear mask. @@ -4563,9 +4724,8 @@ SDValue DAGCombiner::visitAND(SDNode *N) { if (SDValue Res = ReduceLoadWidth(N)) { LoadSDNode *LN0 = N0->getOpcode() == ISD::ANY_EXTEND ? cast<LoadSDNode>(N0.getOperand(0)) : cast<LoadSDNode>(N0); - AddToWorklist(N); - CombineTo(LN0, Res, Res.getValue(1)); + DAG.ReplaceAllUsesOfValueWith(SDValue(LN0, 0), Res); return SDValue(N, 0); } } @@ -4585,8 +4745,8 @@ SDValue DAGCombiner::visitAND(SDNode *N) { // Simplify: (and (op x...), (op y...)) -> (op (and x, y)) if (N0.getOpcode() == N1.getOpcode()) - if (SDValue Tmp = SimplifyBinOpWithSameOpcodeHands(N)) - return Tmp; + if (SDValue V = hoistLogicOpWithSameOpcodeHands(N)) + return V; // Masking the negated extension of a boolean is just the zero-extended // boolean: @@ -4596,7 +4756,7 @@ SDValue DAGCombiner::visitAND(SDNode *N) { // Note: the SimplifyDemandedBits fold below can make an information-losing // transform, and then we have no way to find this better fold. if (N1C && N1C->isOne() && N0.getOpcode() == ISD::SUB) { - if (isNullConstantOrNullSplatConstant(N0.getOperand(0))) { + if (isNullOrNullSplat(N0.getOperand(0))) { SDValue SubRHS = N0.getOperand(1); if (SubRHS.getOpcode() == ISD::ZERO_EXTEND && SubRHS.getOperand(0).getScalarValueSizeInBits() == 1) @@ -5124,16 +5284,16 @@ SDValue DAGCombiner::visitOR(SDNode *N) { return BSwap; // reassociate or - if (SDValue ROR = ReassociateOps(ISD::OR, SDLoc(N), N0, N1)) + if (SDValue ROR = ReassociateOps(ISD::OR, SDLoc(N), N0, N1, N->getFlags())) return ROR; // Canonicalize (or (and X, c1), c2) -> (and (or X, c2), c1|c2) - // iff (c1 & c2) != 0. - auto MatchIntersect = [](ConstantSDNode *LHS, ConstantSDNode *RHS) { - return LHS->getAPIntValue().intersects(RHS->getAPIntValue()); + // iff (c1 & c2) != 0 or c1/c2 are undef. + auto MatchIntersect = [](ConstantSDNode *C1, ConstantSDNode *C2) { + return !C1 || !C2 || C1->getAPIntValue().intersects(C2->getAPIntValue()); }; if (N0.getOpcode() == ISD::AND && N0.getNode()->hasOneUse() && - ISD::matchBinaryPredicate(N0.getOperand(1), N1, MatchIntersect)) { + ISD::matchBinaryPredicate(N0.getOperand(1), N1, MatchIntersect, true)) { if (SDValue COR = DAG.FoldConstantArithmetic( ISD::OR, SDLoc(N1), VT, N1.getNode(), N0.getOperand(1).getNode())) { SDValue IOR = DAG.getNode(ISD::OR, SDLoc(N0), VT, N0.getOperand(0), N1); @@ -5144,8 +5304,8 @@ SDValue DAGCombiner::visitOR(SDNode *N) { // Simplify: (or (op x...), (op y...)) -> (op (or x, y)) if (N0.getOpcode() == N1.getOpcode()) - if (SDValue Tmp = SimplifyBinOpWithSameOpcodeHands(N)) - return Tmp; + if (SDValue V = hoistLogicOpWithSameOpcodeHands(N)) + return V; // See if this is some rotate idiom. if (SDNode *Rot = MatchRotate(N0, N1, SDLoc(N))) @@ -5257,9 +5417,9 @@ static SDValue extractShiftForRotate(SelectionDAG &DAG, SDValue OppShift, // Compute the shift amount we need to extract to complete the rotate. const unsigned VTWidth = ShiftedVT.getScalarSizeInBits(); - APInt NeededShiftAmt = VTWidth - OppShiftCst->getAPIntValue(); - if (NeededShiftAmt.isNegative()) + if (OppShiftCst->getAPIntValue().ugt(VTWidth)) return SDValue(); + APInt NeededShiftAmt = VTWidth - OppShiftCst->getAPIntValue(); // Normalize the bitwidth of the two mul/udiv/shift constant operands. APInt ExtractFromAmt = ExtractFromCst->getAPIntValue(); APInt OppLHSAmt = OppLHSCst->getAPIntValue(); @@ -5340,8 +5500,7 @@ static bool matchRotateSub(SDValue Pos, SDValue Neg, unsigned EltSize, unsigned MaskLoBits = 0; if (Neg.getOpcode() == ISD::AND && isPowerOf2_64(EltSize)) { if (ConstantSDNode *NegC = isConstOrConstSplat(Neg.getOperand(1))) { - KnownBits Known; - DAG.computeKnownBits(Neg.getOperand(0), Known); + KnownBits Known = DAG.computeKnownBits(Neg.getOperand(0)); unsigned Bits = Log2_64(EltSize); if (NegC->getAPIntValue().getActiveBits() <= Bits && ((NegC->getAPIntValue() | Known.Zero).countTrailingOnes() >= Bits)) { @@ -5363,8 +5522,7 @@ static bool matchRotateSub(SDValue Pos, SDValue Neg, unsigned EltSize, // Pos'. The truncation is redundant for the purpose of the equality. if (MaskLoBits && Pos.getOpcode() == ISD::AND) { if (ConstantSDNode *PosC = isConstOrConstSplat(Pos.getOperand(1))) { - KnownBits Known; - DAG.computeKnownBits(Pos.getOperand(0), Known); + KnownBits Known = DAG.computeKnownBits(Pos.getOperand(0)); if (PosC->getAPIntValue().getActiveBits() <= MaskLoBits && ((PosC->getAPIntValue() | Known.Zero).countTrailingOnes() >= MaskLoBits)) @@ -5894,7 +6052,7 @@ SDValue DAGCombiner::unfoldMaskedMerge(SDNode *N) { assert(N->getOpcode() == ISD::XOR); // Don't touch 'not' (i.e. where y = -1). - if (isAllOnesConstantOrAllOnesSplatConstant(N->getOperand(1))) + if (isAllOnesOrAllOnesSplat(N->getOperand(1))) return SDValue(); EVT VT = N->getValueType(0); @@ -5911,7 +6069,7 @@ SDValue DAGCombiner::unfoldMaskedMerge(SDNode *N) { SDValue Xor0 = Xor.getOperand(0); SDValue Xor1 = Xor.getOperand(1); // Don't touch 'not' (i.e. where y = -1). - if (isAllOnesConstantOrAllOnesSplatConstant(Xor1)) + if (isAllOnesOrAllOnesSplat(Xor1)) return false; if (Other == Xor0) std::swap(Xor0, Xor1); @@ -5977,8 +6135,9 @@ SDValue DAGCombiner::visitXOR(SDNode *N) { } // fold (xor undef, undef) -> 0. This is a common idiom (misuse). + SDLoc DL(N); if (N0.isUndef() && N1.isUndef()) - return DAG.getConstant(0, SDLoc(N), VT); + return DAG.getConstant(0, DL, VT); // fold (xor x, undef) -> undef if (N0.isUndef()) return N0; @@ -5988,11 +6147,11 @@ SDValue DAGCombiner::visitXOR(SDNode *N) { ConstantSDNode *N0C = getAsNonOpaqueConstant(N0); ConstantSDNode *N1C = getAsNonOpaqueConstant(N1); if (N0C && N1C) - return DAG.FoldConstantArithmetic(ISD::XOR, SDLoc(N), VT, N0C, N1C); + return DAG.FoldConstantArithmetic(ISD::XOR, DL, VT, N0C, N1C); // canonicalize constant to RHS if (DAG.isConstantIntBuildVectorOrConstantInt(N0) && !DAG.isConstantIntBuildVectorOrConstantInt(N1)) - return DAG.getNode(ISD::XOR, SDLoc(N), VT, N1, N0); + return DAG.getNode(ISD::XOR, DL, VT, N1, N0); // fold (xor x, 0) -> x if (isNullConstant(N1)) return N0; @@ -6001,19 +6160,18 @@ SDValue DAGCombiner::visitXOR(SDNode *N) { return NewSel; // reassociate xor - if (SDValue RXOR = ReassociateOps(ISD::XOR, SDLoc(N), N0, N1)) + if (SDValue RXOR = ReassociateOps(ISD::XOR, DL, N0, N1, N->getFlags())) return RXOR; // fold !(x cc y) -> (x !cc y) + unsigned N0Opcode = N0.getOpcode(); SDValue LHS, RHS, CC; if (TLI.isConstTrueVal(N1.getNode()) && isSetCCEquivalent(N0, LHS, RHS, CC)) { - bool isInt = LHS.getValueType().isInteger(); ISD::CondCode NotCC = ISD::getSetCCInverse(cast<CondCodeSDNode>(CC)->get(), - isInt); - + LHS.getValueType().isInteger()); if (!LegalOperations || TLI.isCondCodeLegal(NotCC, LHS.getSimpleValueType())) { - switch (N0.getOpcode()) { + switch (N0Opcode) { default: llvm_unreachable("Unhandled SetCC Equivalent!"); case ISD::SETCC: @@ -6026,54 +6184,74 @@ SDValue DAGCombiner::visitXOR(SDNode *N) { } // fold (not (zext (setcc x, y))) -> (zext (not (setcc x, y))) - if (isOneConstant(N1) && N0.getOpcode() == ISD::ZERO_EXTEND && - N0.getNode()->hasOneUse() && + if (isOneConstant(N1) && N0Opcode == ISD::ZERO_EXTEND && N0.hasOneUse() && isSetCCEquivalent(N0.getOperand(0), LHS, RHS, CC)){ SDValue V = N0.getOperand(0); - SDLoc DL(N0); - V = DAG.getNode(ISD::XOR, DL, V.getValueType(), V, - DAG.getConstant(1, DL, V.getValueType())); + SDLoc DL0(N0); + V = DAG.getNode(ISD::XOR, DL0, V.getValueType(), V, + DAG.getConstant(1, DL0, V.getValueType())); AddToWorklist(V.getNode()); - return DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N), VT, V); + return DAG.getNode(ISD::ZERO_EXTEND, DL, VT, V); } // fold (not (or x, y)) -> (and (not x), (not y)) iff x or y are setcc if (isOneConstant(N1) && VT == MVT::i1 && N0.hasOneUse() && - (N0.getOpcode() == ISD::OR || N0.getOpcode() == ISD::AND)) { + (N0Opcode == ISD::OR || N0Opcode == ISD::AND)) { SDValue LHS = N0.getOperand(0), RHS = N0.getOperand(1); if (isOneUseSetCC(RHS) || isOneUseSetCC(LHS)) { - unsigned NewOpcode = N0.getOpcode() == ISD::AND ? ISD::OR : ISD::AND; + unsigned NewOpcode = N0Opcode == ISD::AND ? ISD::OR : ISD::AND; LHS = DAG.getNode(ISD::XOR, SDLoc(LHS), VT, LHS, N1); // LHS = ~LHS RHS = DAG.getNode(ISD::XOR, SDLoc(RHS), VT, RHS, N1); // RHS = ~RHS AddToWorklist(LHS.getNode()); AddToWorklist(RHS.getNode()); - return DAG.getNode(NewOpcode, SDLoc(N), VT, LHS, RHS); + return DAG.getNode(NewOpcode, DL, VT, LHS, RHS); } } // fold (not (or x, y)) -> (and (not x), (not y)) iff x or y are constants if (isAllOnesConstant(N1) && N0.hasOneUse() && - (N0.getOpcode() == ISD::OR || N0.getOpcode() == ISD::AND)) { + (N0Opcode == ISD::OR || N0Opcode == ISD::AND)) { SDValue LHS = N0.getOperand(0), RHS = N0.getOperand(1); if (isa<ConstantSDNode>(RHS) || isa<ConstantSDNode>(LHS)) { - unsigned NewOpcode = N0.getOpcode() == ISD::AND ? ISD::OR : ISD::AND; + unsigned NewOpcode = N0Opcode == ISD::AND ? ISD::OR : ISD::AND; LHS = DAG.getNode(ISD::XOR, SDLoc(LHS), VT, LHS, N1); // LHS = ~LHS RHS = DAG.getNode(ISD::XOR, SDLoc(RHS), VT, RHS, N1); // RHS = ~RHS AddToWorklist(LHS.getNode()); AddToWorklist(RHS.getNode()); - return DAG.getNode(NewOpcode, SDLoc(N), VT, LHS, RHS); + return DAG.getNode(NewOpcode, DL, VT, LHS, RHS); } } // fold (xor (and x, y), y) -> (and (not x), y) - if (N0.getOpcode() == ISD::AND && N0.getNode()->hasOneUse() && - N0->getOperand(1) == N1) { - SDValue X = N0->getOperand(0); + if (N0Opcode == ISD::AND && N0.hasOneUse() && N0->getOperand(1) == N1) { + SDValue X = N0.getOperand(0); SDValue NotX = DAG.getNOT(SDLoc(X), X, VT); AddToWorklist(NotX.getNode()); - return DAG.getNode(ISD::AND, SDLoc(N), VT, NotX, N1); + return DAG.getNode(ISD::AND, DL, VT, NotX, N1); + } + + if ((N0Opcode == ISD::SRL || N0Opcode == ISD::SHL) && N0.hasOneUse()) { + ConstantSDNode *XorC = isConstOrConstSplat(N1); + ConstantSDNode *ShiftC = isConstOrConstSplat(N0.getOperand(1)); + unsigned BitWidth = VT.getScalarSizeInBits(); + if (XorC && ShiftC) { + // Don't crash on an oversized shift. We can not guarantee that a bogus + // shift has been simplified to undef. + uint64_t ShiftAmt = ShiftC->getLimitedValue(); + if (ShiftAmt < BitWidth) { + APInt Ones = APInt::getAllOnesValue(BitWidth); + Ones = N0Opcode == ISD::SHL ? Ones.shl(ShiftAmt) : Ones.lshr(ShiftAmt); + if (XorC->getAPIntValue() == Ones) { + // If the xor constant is a shifted -1, do a 'not' before the shift: + // xor (X << ShiftC), XorC --> (not X) << ShiftC + // xor (X >> ShiftC), XorC --> (not X) >> ShiftC + SDValue Not = DAG.getNOT(DL, N0.getOperand(0), VT); + return DAG.getNode(N0Opcode, DL, VT, Not, N0.getOperand(1)); + } + } + } } // fold Y = sra (X, size(X)-1); xor (add (X, Y), Y) -> (abs X) if (TLI.isOperationLegalOrCustom(ISD::ABS, VT)) { - SDValue A = N0.getOpcode() == ISD::ADD ? N0 : N1; - SDValue S = N0.getOpcode() == ISD::SRA ? N0 : N1; + SDValue A = N0Opcode == ISD::ADD ? N0 : N1; + SDValue S = N0Opcode == ISD::SRA ? N0 : N1; if (A.getOpcode() == ISD::ADD && S.getOpcode() == ISD::SRA) { SDValue A0 = A.getOperand(0), A1 = A.getOperand(1); SDValue S0 = S.getOperand(0); @@ -6081,14 +6259,14 @@ SDValue DAGCombiner::visitXOR(SDNode *N) { unsigned OpSizeInBits = VT.getScalarSizeInBits(); if (ConstantSDNode *C = isConstOrConstSplat(S.getOperand(1))) if (C->getAPIntValue() == (OpSizeInBits - 1)) - return DAG.getNode(ISD::ABS, SDLoc(N), VT, S0); + return DAG.getNode(ISD::ABS, DL, VT, S0); } } } // fold (xor x, x) -> 0 if (N0 == N1) - return tryFoldToZero(SDLoc(N), TLI, VT, DAG, LegalOperations, LegalTypes); + return tryFoldToZero(DL, TLI, VT, DAG, LegalOperations); // fold (xor (shl 1, x), -1) -> (rotl ~1, x) // Here is a concrete example of this equivalence: @@ -6108,17 +6286,16 @@ SDValue DAGCombiner::visitXOR(SDNode *N) { // consistent result. // - Pushing the zero left requires shifting one bits in from the right. // A rotate left of ~1 is a nice way of achieving the desired result. - if (TLI.isOperationLegalOrCustom(ISD::ROTL, VT) && N0.getOpcode() == ISD::SHL - && isAllOnesConstant(N1) && isOneConstant(N0.getOperand(0))) { - SDLoc DL(N); + if (TLI.isOperationLegalOrCustom(ISD::ROTL, VT) && N0Opcode == ISD::SHL && + isAllOnesConstant(N1) && isOneConstant(N0.getOperand(0))) { return DAG.getNode(ISD::ROTL, DL, VT, DAG.getConstant(~1, DL, VT), N0.getOperand(1)); } // Simplify: xor (op x...), (op y...) -> (op (xor x, y)) - if (N0.getOpcode() == N1.getOpcode()) - if (SDValue Tmp = SimplifyBinOpWithSameOpcodeHands(N)) - return Tmp; + if (N0Opcode == N1.getOpcode()) + if (SDValue V = hoistLogicOpWithSameOpcodeHands(N)) + return V; // Unfold ((x ^ y) & m) ^ y into (x & m) | (y & ~m) if profitable if (SDValue MM = unfoldMaskedMerge(N)) @@ -6134,6 +6311,10 @@ SDValue DAGCombiner::visitXOR(SDNode *N) { /// Handle transforms common to the three shifts, when the shift amount is a /// constant. SDValue DAGCombiner::visitShiftByConstant(SDNode *N, ConstantSDNode *Amt) { + // Do not turn a 'not' into a regular xor. + if (isBitwiseNot(N->getOperand(0))) + return SDValue(); + SDNode *LHS = N->getOperand(0).getNode(); if (!LHS->hasOneUse()) return SDValue(); @@ -6191,7 +6372,7 @@ SDValue DAGCombiner::visitShiftByConstant(SDNode *N, ConstantSDNode *Amt) { return SDValue(); } - if (!TLI.isDesirableToCommuteWithShift(LHS)) + if (!TLI.isDesirableToCommuteWithShift(N, Level)) return SDValue(); // Fold the constants, shifting the binop RHS by the shift amount. @@ -6239,9 +6420,16 @@ SDValue DAGCombiner::visitRotate(SDNode *N) { unsigned Bitsize = VT.getScalarSizeInBits(); // fold (rot x, 0) -> x - if (isNullConstantOrNullSplatConstant(N1)) + if (isNullOrNullSplat(N1)) return N0; + // fold (rot x, c) -> x iff (c % BitSize) == 0 + if (isPowerOf2_32(Bitsize) && Bitsize > 1) { + APInt ModuloMask(N1.getScalarValueSizeInBits(), Bitsize - 1); + if (DAG.MaskedValueIsZero(N1, ModuloMask)) + return N0; + } + // fold (rot x, c) -> (rot x, c % BitSize) if (ConstantSDNode *Cst = isConstOrConstSplat(N1)) { if (Cst->getAPIntValue().uge(Bitsize)) { @@ -6284,6 +6472,9 @@ SDValue DAGCombiner::visitRotate(SDNode *N) { SDValue DAGCombiner::visitSHL(SDNode *N) { SDValue N0 = N->getOperand(0); SDValue N1 = N->getOperand(1); + if (SDValue V = DAG.simplifyShift(N0, N1)) + return V; + EVT VT = N0.getValueType(); unsigned OpSizeInBits = VT.getScalarSizeInBits(); @@ -6318,22 +6509,6 @@ SDValue DAGCombiner::visitSHL(SDNode *N) { ConstantSDNode *N0C = getAsNonOpaqueConstant(N0); if (N0C && N1C && !N1C->isOpaque()) return DAG.FoldConstantArithmetic(ISD::SHL, SDLoc(N), VT, N0C, N1C); - // fold (shl 0, x) -> 0 - if (isNullConstantOrNullSplatConstant(N0)) - return N0; - // fold (shl x, c >= size(x)) -> undef - // NOTE: ALL vector elements must be too big to avoid partial UNDEFs. - auto MatchShiftTooBig = [OpSizeInBits](ConstantSDNode *Val) { - return Val->getAPIntValue().uge(OpSizeInBits); - }; - if (ISD::matchUnaryPredicate(N1, MatchShiftTooBig)) - return DAG.getUNDEF(VT); - // fold (shl x, 0) -> x - if (N1C && N1C->isNullValue()) - return N0; - // fold (shl undef, x) -> 0 - if (N0.isUndef()) - return DAG.getConstant(0, SDLoc(N), VT); if (SDValue NewSel = foldBinOpIntoSelect(N)) return NewSel; @@ -6454,7 +6629,8 @@ SDValue DAGCombiner::visitSHL(SDNode *N) { // (and (srl x, (sub c1, c2), MASK) // Only fold this if the inner shift has no other uses -- if it does, folding // this will increase the total number of instructions. - if (N1C && N0.getOpcode() == ISD::SRL && N0.hasOneUse()) { + if (N1C && N0.getOpcode() == ISD::SRL && N0.hasOneUse() && + TLI.shouldFoldShiftPairToMask(N, Level)) { if (ConstantSDNode *N0C1 = isConstOrConstSplat(N0.getOperand(1))) { uint64_t c1 = N0C1->getZExtValue(); if (c1 < OpSizeInBits) { @@ -6495,7 +6671,8 @@ SDValue DAGCombiner::visitSHL(SDNode *N) { if ((N0.getOpcode() == ISD::ADD || N0.getOpcode() == ISD::OR) && N0.getNode()->hasOneUse() && isConstantOrConstantVector(N1, /* No Opaques */ true) && - isConstantOrConstantVector(N0.getOperand(1), /* No Opaques */ true)) { + isConstantOrConstantVector(N0.getOperand(1), /* No Opaques */ true) && + TLI.isDesirableToCommuteWithShift(N, Level)) { SDValue Shl0 = DAG.getNode(ISD::SHL, SDLoc(N0), VT, N0.getOperand(0), N1); SDValue Shl1 = DAG.getNode(ISD::SHL, SDLoc(N1), VT, N0.getOperand(1), N1); AddToWorklist(Shl0.getNode()); @@ -6522,6 +6699,9 @@ SDValue DAGCombiner::visitSHL(SDNode *N) { SDValue DAGCombiner::visitSRA(SDNode *N) { SDValue N0 = N->getOperand(0); SDValue N1 = N->getOperand(1); + if (SDValue V = DAG.simplifyShift(N0, N1)) + return V; + EVT VT = N0.getValueType(); unsigned OpSizeInBits = VT.getScalarSizeInBits(); @@ -6542,16 +6722,6 @@ SDValue DAGCombiner::visitSRA(SDNode *N) { ConstantSDNode *N0C = getAsNonOpaqueConstant(N0); if (N0C && N1C && !N1C->isOpaque()) return DAG.FoldConstantArithmetic(ISD::SRA, SDLoc(N), VT, N0C, N1C); - // fold (sra x, c >= size(x)) -> undef - // NOTE: ALL vector elements must be too big to avoid partial UNDEFs. - auto MatchShiftTooBig = [OpSizeInBits](ConstantSDNode *Val) { - return Val->getAPIntValue().uge(OpSizeInBits); - }; - if (ISD::matchUnaryPredicate(N1, MatchShiftTooBig)) - return DAG.getUNDEF(VT); - // fold (sra x, 0) -> x - if (N1C && N1C->isNullValue()) - return N0; if (SDValue NewSel = foldBinOpIntoSelect(N)) return NewSel; @@ -6571,31 +6741,30 @@ SDValue DAGCombiner::visitSRA(SDNode *N) { } // fold (sra (sra x, c1), c2) -> (sra x, (add c1, c2)) + // clamp (add c1, c2) to max shift. if (N0.getOpcode() == ISD::SRA) { SDLoc DL(N); EVT ShiftVT = N1.getValueType(); + EVT ShiftSVT = ShiftVT.getScalarType(); + SmallVector<SDValue, 16> ShiftValues; - auto MatchOutOfRange = [OpSizeInBits](ConstantSDNode *LHS, - ConstantSDNode *RHS) { + auto SumOfShifts = [&](ConstantSDNode *LHS, ConstantSDNode *RHS) { APInt c1 = LHS->getAPIntValue(); APInt c2 = RHS->getAPIntValue(); zeroExtendToMatch(c1, c2, 1 /* Overflow Bit */); - return (c1 + c2).uge(OpSizeInBits); - }; - if (ISD::matchBinaryPredicate(N1, N0.getOperand(1), MatchOutOfRange)) - return DAG.getNode(ISD::SRA, DL, VT, N0.getOperand(0), - DAG.getConstant(OpSizeInBits - 1, DL, ShiftVT)); - - auto MatchInRange = [OpSizeInBits](ConstantSDNode *LHS, - ConstantSDNode *RHS) { - APInt c1 = LHS->getAPIntValue(); - APInt c2 = RHS->getAPIntValue(); - zeroExtendToMatch(c1, c2, 1 /* Overflow Bit */); - return (c1 + c2).ult(OpSizeInBits); + APInt Sum = c1 + c2; + unsigned ShiftSum = + Sum.uge(OpSizeInBits) ? (OpSizeInBits - 1) : Sum.getZExtValue(); + ShiftValues.push_back(DAG.getConstant(ShiftSum, DL, ShiftSVT)); + return true; }; - if (ISD::matchBinaryPredicate(N1, N0.getOperand(1), MatchInRange)) { - SDValue Sum = DAG.getNode(ISD::ADD, DL, ShiftVT, N1, N0.getOperand(1)); - return DAG.getNode(ISD::SRA, DL, VT, N0.getOperand(0), Sum); + if (ISD::matchBinaryPredicate(N1, N0.getOperand(1), SumOfShifts)) { + SDValue ShiftValue; + if (VT.isVector()) + ShiftValue = DAG.getBuildVector(ShiftVT, DL, ShiftValues); + else + ShiftValue = ShiftValues[0]; + return DAG.getNode(ISD::SRA, DL, VT, N0.getOperand(0), ShiftValue); } } @@ -6689,6 +6858,9 @@ SDValue DAGCombiner::visitSRA(SDNode *N) { SDValue DAGCombiner::visitSRL(SDNode *N) { SDValue N0 = N->getOperand(0); SDValue N1 = N->getOperand(1); + if (SDValue V = DAG.simplifyShift(N0, N1)) + return V; + EVT VT = N0.getValueType(); unsigned OpSizeInBits = VT.getScalarSizeInBits(); @@ -6703,19 +6875,6 @@ SDValue DAGCombiner::visitSRL(SDNode *N) { ConstantSDNode *N0C = getAsNonOpaqueConstant(N0); if (N0C && N1C && !N1C->isOpaque()) return DAG.FoldConstantArithmetic(ISD::SRL, SDLoc(N), VT, N0C, N1C); - // fold (srl 0, x) -> 0 - if (isNullConstantOrNullSplatConstant(N0)) - return N0; - // fold (srl x, c >= size(x)) -> undef - // NOTE: ALL vector elements must be too big to avoid partial UNDEFs. - auto MatchShiftTooBig = [OpSizeInBits](ConstantSDNode *Val) { - return Val->getAPIntValue().uge(OpSizeInBits); - }; - if (ISD::matchUnaryPredicate(N1, MatchShiftTooBig)) - return DAG.getUNDEF(VT); - // fold (srl x, 0) -> x - if (N1C && N1C->isNullValue()) - return N0; if (SDValue NewSel = foldBinOpIntoSelect(N)) return NewSel; @@ -6819,8 +6978,7 @@ SDValue DAGCombiner::visitSRL(SDNode *N) { // fold (srl (ctlz x), "5") -> x iff x has one bit set (the low bit). if (N1C && N0.getOpcode() == ISD::CTLZ && N1C->getAPIntValue() == Log2_32(OpSizeInBits)) { - KnownBits Known; - DAG.computeKnownBits(N0.getOperand(0), Known); + KnownBits Known = DAG.computeKnownBits(N0.getOperand(0)); // If any of the input bits are KnownOne, then the input couldn't be all // zeros, thus the result of the srl will always be zero. @@ -6906,6 +7064,41 @@ SDValue DAGCombiner::visitSRL(SDNode *N) { return SDValue(); } +SDValue DAGCombiner::visitFunnelShift(SDNode *N) { + EVT VT = N->getValueType(0); + SDValue N0 = N->getOperand(0); + SDValue N1 = N->getOperand(1); + SDValue N2 = N->getOperand(2); + bool IsFSHL = N->getOpcode() == ISD::FSHL; + unsigned BitWidth = VT.getScalarSizeInBits(); + + // fold (fshl N0, N1, 0) -> N0 + // fold (fshr N0, N1, 0) -> N1 + if (isPowerOf2_32(BitWidth)) + if (DAG.MaskedValueIsZero( + N2, APInt(N2.getScalarValueSizeInBits(), BitWidth - 1))) + return IsFSHL ? N0 : N1; + + // fold (fsh* N0, N1, c) -> (fsh* N0, N1, c % BitWidth) + if (ConstantSDNode *Cst = isConstOrConstSplat(N2)) { + if (Cst->getAPIntValue().uge(BitWidth)) { + uint64_t RotAmt = Cst->getAPIntValue().urem(BitWidth); + return DAG.getNode(N->getOpcode(), SDLoc(N), VT, N0, N1, + DAG.getConstant(RotAmt, SDLoc(N), N2.getValueType())); + } + } + + // fold (fshl N0, N0, N2) -> (rotl N0, N2) + // fold (fshr N0, N0, N2) -> (rotr N0, N2) + // TODO: Investigate flipping this rotate if only one is legal, if funnel shift + // is legal as well we might be better off avoiding non-constant (BW - N2). + unsigned RotOpc = IsFSHL ? ISD::ROTL : ISD::ROTR; + if (N0 == N1 && hasOperation(RotOpc, VT)) + return DAG.getNode(RotOpc, SDLoc(N), VT, N0, N2); + + return SDValue(); +} + SDValue DAGCombiner::visitABS(SDNode *N) { SDValue N0 = N->getOperand(0); EVT VT = N->getValueType(0); @@ -7012,6 +7205,16 @@ SDValue DAGCombiner::visitCTPOP(SDNode *N) { return SDValue(); } +// FIXME: This should be checking for no signed zeros on individual operands, as +// well as no nans. +static bool isLegalToCombineMinNumMaxNum(SelectionDAG &DAG, SDValue LHS, SDValue RHS) { + const TargetOptions &Options = DAG.getTarget().Options; + EVT VT = LHS.getValueType(); + + return Options.NoSignedZerosFPMath && VT.isFloatingPoint() && + DAG.isKnownNeverNaN(LHS) && DAG.isKnownNeverNaN(RHS); +} + /// Generate Min/Max node static SDValue combineMinNumMaxNum(const SDLoc &DL, EVT VT, SDValue LHS, SDValue RHS, SDValue True, SDValue False, @@ -7020,6 +7223,7 @@ static SDValue combineMinNumMaxNum(const SDLoc &DL, EVT VT, SDValue LHS, if (!(LHS == True && RHS == False) && !(LHS == False && RHS == True)) return SDValue(); + EVT TransformVT = TLI.getTypeToTransformTo(*DAG.getContext(), VT); switch (CC) { case ISD::SETOLT: case ISD::SETOLE: @@ -7027,8 +7231,15 @@ static SDValue combineMinNumMaxNum(const SDLoc &DL, EVT VT, SDValue LHS, case ISD::SETLE: case ISD::SETULT: case ISD::SETULE: { + // Since it's known never nan to get here already, either fminnum or + // fminnum_ieee are OK. Try the ieee version first, since it's fminnum is + // expanded in terms of it. + unsigned IEEEOpcode = (LHS == True) ? ISD::FMINNUM_IEEE : ISD::FMAXNUM_IEEE; + if (TLI.isOperationLegalOrCustom(IEEEOpcode, VT)) + return DAG.getNode(IEEEOpcode, DL, VT, LHS, RHS); + unsigned Opcode = (LHS == True) ? ISD::FMINNUM : ISD::FMAXNUM; - if (TLI.isOperationLegal(Opcode, VT)) + if (TLI.isOperationLegalOrCustom(Opcode, TransformVT)) return DAG.getNode(Opcode, DL, VT, LHS, RHS); return SDValue(); } @@ -7038,8 +7249,12 @@ static SDValue combineMinNumMaxNum(const SDLoc &DL, EVT VT, SDValue LHS, case ISD::SETGE: case ISD::SETUGT: case ISD::SETUGE: { + unsigned IEEEOpcode = (LHS == True) ? ISD::FMAXNUM_IEEE : ISD::FMINNUM_IEEE; + if (TLI.isOperationLegalOrCustom(IEEEOpcode, VT)) + return DAG.getNode(IEEEOpcode, DL, VT, LHS, RHS); + unsigned Opcode = (LHS == True) ? ISD::FMAXNUM : ISD::FMINNUM; - if (TLI.isOperationLegal(Opcode, VT)) + if (TLI.isOperationLegalOrCustom(Opcode, TransformVT)) return DAG.getNode(Opcode, DL, VT, LHS, RHS); return SDValue(); } @@ -7150,15 +7365,8 @@ SDValue DAGCombiner::visitSELECT(SDNode *N) { EVT VT0 = N0.getValueType(); SDLoc DL(N); - // fold (select C, X, X) -> X - if (N1 == N2) - return N1; - - if (const ConstantSDNode *N0C = dyn_cast<const ConstantSDNode>(N0)) { - // fold (select true, X, Y) -> X - // fold (select false, X, Y) -> Y - return !N0C->isNullValue() ? N1 : N2; - } + if (SDValue V = DAG.simplifySelect(N0, N1, N2)) + return V; // fold (select X, X, Y) -> (or X, Y) // fold (select X, 1, Y) -> (or C, Y) @@ -7264,32 +7472,54 @@ SDValue DAGCombiner::visitSELECT(SDNode *N) { return DAG.getNode(ISD::SELECT, DL, VT, N0->getOperand(0), N2, N1); } - // fold selects based on a setcc into other things, such as min/max/abs + // Fold selects based on a setcc into other things, such as min/max/abs. if (N0.getOpcode() == ISD::SETCC) { - // select x, y (fcmp lt x, y) -> fminnum x, y - // select x, y (fcmp gt x, y) -> fmaxnum x, y - // - // This is OK if we don't care about what happens if either operand is a - // NaN. - // - - // FIXME: Instead of testing for UnsafeFPMath, this should be checking for - // no signed zeros as well as no nans. - const TargetOptions &Options = DAG.getTarget().Options; - if (Options.UnsafeFPMath && VT.isFloatingPoint() && N0.hasOneUse() && - DAG.isKnownNeverNaN(N1) && DAG.isKnownNeverNaN(N2)) { - ISD::CondCode CC = cast<CondCodeSDNode>(N0.getOperand(2))->get(); + SDValue Cond0 = N0.getOperand(0), Cond1 = N0.getOperand(1); + ISD::CondCode CC = cast<CondCodeSDNode>(N0.getOperand(2))->get(); - if (SDValue FMinMax = combineMinNumMaxNum( - DL, VT, N0.getOperand(0), N0.getOperand(1), N1, N2, CC, TLI, DAG)) + // select (fcmp lt x, y), x, y -> fminnum x, y + // select (fcmp gt x, y), x, y -> fmaxnum x, y + // + // This is OK if we don't care what happens if either operand is a NaN. + if (N0.hasOneUse() && isLegalToCombineMinNumMaxNum(DAG, N1, N2)) + if (SDValue FMinMax = combineMinNumMaxNum(DL, VT, Cond0, Cond1, N1, N2, + CC, TLI, DAG)) return FMinMax; + + // Use 'unsigned add with overflow' to optimize an unsigned saturating add. + // This is conservatively limited to pre-legal-operations to give targets + // a chance to reverse the transform if they want to do that. Also, it is + // unlikely that the pattern would be formed late, so it's probably not + // worth going through the other checks. + if (!LegalOperations && TLI.isOperationLegalOrCustom(ISD::UADDO, VT) && + CC == ISD::SETUGT && N0.hasOneUse() && isAllOnesConstant(N1) && + N2.getOpcode() == ISD::ADD && Cond0 == N2.getOperand(0)) { + auto *C = dyn_cast<ConstantSDNode>(N2.getOperand(1)); + auto *NotC = dyn_cast<ConstantSDNode>(Cond1); + if (C && NotC && C->getAPIntValue() == ~NotC->getAPIntValue()) { + // select (setcc Cond0, ~C, ugt), -1, (add Cond0, C) --> + // uaddo Cond0, C; select uaddo.1, -1, uaddo.0 + // + // The IR equivalent of this transform would have this form: + // %a = add %x, C + // %c = icmp ugt %x, ~C + // %r = select %c, -1, %a + // => + // %u = call {iN,i1} llvm.uadd.with.overflow(%x, C) + // %u0 = extractvalue %u, 0 + // %u1 = extractvalue %u, 1 + // %r = select %u1, -1, %u0 + SDVTList VTs = DAG.getVTList(VT, VT0); + SDValue UAO = DAG.getNode(ISD::UADDO, DL, VTs, Cond0, N2.getOperand(1)); + return DAG.getSelect(DL, VT, UAO.getValue(1), N1, UAO.getValue(0)); + } } - if ((!LegalOperations && - TLI.isOperationLegalOrCustom(ISD::SELECT_CC, VT)) || - TLI.isOperationLegal(ISD::SELECT_CC, VT)) - return DAG.getNode(ISD::SELECT_CC, DL, VT, N0.getOperand(0), - N0.getOperand(1), N1, N2, N0.getOperand(2)); + if (TLI.isOperationLegal(ISD::SELECT_CC, VT) || + (!LegalOperations && TLI.isOperationLegalOrCustom(ISD::SELECT_CC, VT))) + return DAG.getNode(ISD::SELECT_CC, DL, VT, Cond0, Cond1, N1, N2, + N0.getOperand(2)); + return SimplifySelect(DL, N0, N1, N2); } @@ -7388,7 +7618,7 @@ SDValue DAGCombiner::visitMSCATTER(SDNode *N) { if (TLI.getTypeAction(*DAG.getContext(), Data.getValueType()) != TargetLowering::TypeSplitVector) return SDValue(); - SDValue MaskLo, MaskHi, Lo, Hi; + SDValue MaskLo, MaskHi; std::tie(MaskLo, MaskHi) = SplitVSETCC(Mask.getNode(), DAG); EVT LoVT, HiVT; @@ -7416,17 +7646,15 @@ SDValue DAGCombiner::visitMSCATTER(SDNode *N) { Alignment, MSC->getAAInfo(), MSC->getRanges()); SDValue OpsLo[] = { Chain, DataLo, MaskLo, BasePtr, IndexLo, Scale }; - Lo = DAG.getMaskedScatter(DAG.getVTList(MVT::Other), DataLo.getValueType(), - DL, OpsLo, MMO); + SDValue Lo = DAG.getMaskedScatter(DAG.getVTList(MVT::Other), + DataLo.getValueType(), DL, OpsLo, MMO); - SDValue OpsHi[] = { Chain, DataHi, MaskHi, BasePtr, IndexHi, Scale }; - Hi = DAG.getMaskedScatter(DAG.getVTList(MVT::Other), DataHi.getValueType(), - DL, OpsHi, MMO); - - AddToWorklist(Lo.getNode()); - AddToWorklist(Hi.getNode()); - - return DAG.getNode(ISD::TokenFactor, DL, MVT::Other, Lo, Hi); + // The order of the Scatter operation after split is well defined. The "Hi" + // part comes after the "Lo". So these two operations should be chained one + // after another. + SDValue OpsHi[] = { Lo, DataHi, MaskHi, BasePtr, IndexHi, Scale }; + return DAG.getMaskedScatter(DAG.getVTList(MVT::Other), DataHi.getValueType(), + DL, OpsHi, MMO); } SDValue DAGCombiner::visitMSTORE(SDNode *N) { @@ -7525,9 +7753,9 @@ SDValue DAGCombiner::visitMGATHER(SDNode *N) { SDValue MaskLo, MaskHi, Lo, Hi; std::tie(MaskLo, MaskHi) = SplitVSETCC(Mask.getNode(), DAG); - SDValue Src0 = MGT->getValue(); - SDValue Src0Lo, Src0Hi; - std::tie(Src0Lo, Src0Hi) = DAG.SplitVector(Src0, DL); + SDValue PassThru = MGT->getPassThru(); + SDValue PassThruLo, PassThruHi; + std::tie(PassThruLo, PassThruHi) = DAG.SplitVector(PassThru, DL); EVT LoVT, HiVT; std::tie(LoVT, HiVT) = DAG.GetSplitDestVTs(VT); @@ -7550,11 +7778,11 @@ SDValue DAGCombiner::visitMGATHER(SDNode *N) { MachineMemOperand::MOLoad, LoMemVT.getStoreSize(), Alignment, MGT->getAAInfo(), MGT->getRanges()); - SDValue OpsLo[] = { Chain, Src0Lo, MaskLo, BasePtr, IndexLo, Scale }; + SDValue OpsLo[] = { Chain, PassThruLo, MaskLo, BasePtr, IndexLo, Scale }; Lo = DAG.getMaskedGather(DAG.getVTList(LoVT, MVT::Other), LoVT, DL, OpsLo, MMO); - SDValue OpsHi[] = { Chain, Src0Hi, MaskHi, BasePtr, IndexHi, Scale }; + SDValue OpsHi[] = { Chain, PassThruHi, MaskHi, BasePtr, IndexHi, Scale }; Hi = DAG.getMaskedGather(DAG.getVTList(HiVT, MVT::Other), HiVT, DL, OpsHi, MMO); @@ -7599,9 +7827,9 @@ SDValue DAGCombiner::visitMLOAD(SDNode *N) { SDValue MaskLo, MaskHi, Lo, Hi; std::tie(MaskLo, MaskHi) = SplitVSETCC(Mask.getNode(), DAG); - SDValue Src0 = MLD->getSrc0(); - SDValue Src0Lo, Src0Hi; - std::tie(Src0Lo, Src0Hi) = DAG.SplitVector(Src0, DL); + SDValue PassThru = MLD->getPassThru(); + SDValue PassThruLo, PassThruHi; + std::tie(PassThruLo, PassThruHi) = DAG.SplitVector(PassThru, DL); EVT LoVT, HiVT; std::tie(LoVT, HiVT) = DAG.GetSplitDestVTs(MLD->getValueType(0)); @@ -7625,8 +7853,8 @@ SDValue DAGCombiner::visitMLOAD(SDNode *N) { MachineMemOperand::MOLoad, LoMemVT.getStoreSize(), Alignment, MLD->getAAInfo(), MLD->getRanges()); - Lo = DAG.getMaskedLoad(LoVT, DL, Chain, Ptr, MaskLo, Src0Lo, LoMemVT, MMO, - ISD::NON_EXTLOAD, MLD->isExpandingLoad()); + Lo = DAG.getMaskedLoad(LoVT, DL, Chain, Ptr, MaskLo, PassThruLo, LoMemVT, + MMO, ISD::NON_EXTLOAD, MLD->isExpandingLoad()); Ptr = TLI.IncrementMemoryAddress(Ptr, MaskLo, DL, LoMemVT, DAG, MLD->isExpandingLoad()); @@ -7637,8 +7865,8 @@ SDValue DAGCombiner::visitMLOAD(SDNode *N) { MachineMemOperand::MOLoad, HiMemVT.getStoreSize(), SecondHalfAlignment, MLD->getAAInfo(), MLD->getRanges()); - Hi = DAG.getMaskedLoad(HiVT, DL, Chain, Ptr, MaskHi, Src0Hi, HiMemVT, MMO, - ISD::NON_EXTLOAD, MLD->isExpandingLoad()); + Hi = DAG.getMaskedLoad(HiVT, DL, Chain, Ptr, MaskHi, PassThruHi, HiMemVT, + MMO, ISD::NON_EXTLOAD, MLD->isExpandingLoad()); AddToWorklist(Lo.getNode()); AddToWorklist(Hi.getNode()); @@ -7717,9 +7945,8 @@ SDValue DAGCombiner::visitVSELECT(SDNode *N) { SDValue N2 = N->getOperand(2); SDLoc DL(N); - // fold (vselect C, X, X) -> X - if (N1 == N2) - return N1; + if (SDValue V = DAG.simplifySelect(N0, N1, N2)) + return V; // Canonicalize integer abs. // vselect (setg[te] X, 0), X, -X -> @@ -7754,12 +7981,26 @@ SDValue DAGCombiner::visitVSELECT(SDNode *N) { return DAG.getNode(ISD::XOR, DL, VT, Add, Shift); } + // vselect x, y (fcmp lt x, y) -> fminnum x, y + // vselect x, y (fcmp gt x, y) -> fmaxnum x, y + // + // This is OK if we don't care about what happens if either operand is a + // NaN. + // + EVT VT = N->getValueType(0); + if (N0.hasOneUse() && isLegalToCombineMinNumMaxNum(DAG, N0.getOperand(0), N0.getOperand(1))) { + ISD::CondCode CC = cast<CondCodeSDNode>(N0.getOperand(2))->get(); + if (SDValue FMinMax = combineMinNumMaxNum( + DL, VT, N0.getOperand(0), N0.getOperand(1), N1, N2, CC, TLI, DAG)) + return FMinMax; + } + // If this select has a condition (setcc) with narrower operands than the // select, try to widen the compare to match the select width. // TODO: This should be extended to handle any constant. // TODO: This could be extended to handle non-loading patterns, but that // requires thorough testing to avoid regressions. - if (isNullConstantOrNullSplatConstant(RHS)) { + if (isNullOrNullSplat(RHS)) { EVT NarrowVT = LHS.getValueType(); EVT WideVT = N1.getValueType().changeVectorElementTypeToInteger(); EVT SetCCVT = getSetCCResultType(LHS.getValueType()); @@ -7902,9 +8143,8 @@ SDValue DAGCombiner::visitSETCCCARRY(SDNode *N) { /// dag nodes (see for example method DAGCombiner::visitSIGN_EXTEND). /// Vector extends are not folded if operations are legal; this is to /// avoid introducing illegal build_vector dag nodes. -static SDNode *tryToFoldExtendOfConstant(SDNode *N, const TargetLowering &TLI, - SelectionDAG &DAG, bool LegalTypes, - bool LegalOperations) { +static SDValue tryToFoldExtendOfConstant(SDNode *N, const TargetLowering &TLI, + SelectionDAG &DAG, bool LegalTypes) { unsigned Opcode = N->getOpcode(); SDValue N0 = N->getOperand(0); EVT VT = N->getValueType(0); @@ -7918,16 +8158,15 @@ static SDNode *tryToFoldExtendOfConstant(SDNode *N, const TargetLowering &TLI, // fold (zext c1) -> c1 // fold (aext c1) -> c1 if (isa<ConstantSDNode>(N0)) - return DAG.getNode(Opcode, SDLoc(N), VT, N0).getNode(); + return DAG.getNode(Opcode, SDLoc(N), VT, N0); // fold (sext (build_vector AllConstants) -> (build_vector AllConstants) // fold (zext (build_vector AllConstants) -> (build_vector AllConstants) // fold (aext (build_vector AllConstants) -> (build_vector AllConstants) EVT SVT = VT.getScalarType(); - if (!(VT.isVector() && - (!LegalTypes || (!LegalOperations && TLI.isTypeLegal(SVT))) && + if (!(VT.isVector() && (!LegalTypes || TLI.isTypeLegal(SVT)) && ISD::isBuildVectorOfConstantSDNodes(N0.getNode()))) - return nullptr; + return SDValue(); // We can fold this node into a build_vector. unsigned VTBits = SVT.getSizeInBits(); @@ -7936,10 +8175,15 @@ static SDNode *tryToFoldExtendOfConstant(SDNode *N, const TargetLowering &TLI, unsigned NumElts = VT.getVectorNumElements(); SDLoc DL(N); - for (unsigned i=0; i != NumElts; ++i) { - SDValue Op = N0->getOperand(i); - if (Op->isUndef()) { - Elts.push_back(DAG.getUNDEF(SVT)); + // For zero-extensions, UNDEF elements still guarantee to have the upper + // bits set to zero. + bool IsZext = + Opcode == ISD::ZERO_EXTEND || Opcode == ISD::ZERO_EXTEND_VECTOR_INREG; + + for (unsigned i = 0; i != NumElts; ++i) { + SDValue Op = N0.getOperand(i); + if (Op.isUndef()) { + Elts.push_back(IsZext ? DAG.getConstant(0, DL, SVT) : DAG.getUNDEF(SVT)); continue; } @@ -7953,7 +8197,7 @@ static SDNode *tryToFoldExtendOfConstant(SDNode *N, const TargetLowering &TLI, Elts.push_back(DAG.getConstant(C.zext(VTBits), DL, SVT)); } - return DAG.getBuildVector(VT, DL, Elts).getNode(); + return DAG.getBuildVector(VT, DL, Elts); } // ExtendUsesToFormExtLoad - Trying to extend uses of a load to enable this: @@ -8269,7 +8513,7 @@ static SDValue tryToFoldExtOfExtload(SelectionDAG &DAG, DAGCombiner &Combiner, LoadSDNode *LN0 = cast<LoadSDNode>(N0); EVT MemVT = LN0->getMemoryVT(); - if ((LegalOperations || LN0->isVolatile()) && + if ((LegalOperations || LN0->isVolatile() || VT.isVector()) && !TLI.isLoadExtLegal(ExtLoadType, VT, MemVT)) return {}; @@ -8359,9 +8603,8 @@ SDValue DAGCombiner::visitSIGN_EXTEND(SDNode *N) { EVT VT = N->getValueType(0); SDLoc DL(N); - if (SDNode *Res = tryToFoldExtendOfConstant(N, TLI, DAG, LegalTypes, - LegalOperations)) - return SDValue(Res, 0); + if (SDValue Res = tryToFoldExtendOfConstant(N, TLI, DAG, LegalTypes)) + return Res; // fold (sext (sext x)) -> (sext x) // fold (sext (aext x)) -> (sext x) @@ -8498,21 +8741,24 @@ SDValue DAGCombiner::visitSIGN_EXTEND(SDNode *N) { // if this is the case. EVT SVT = getSetCCResultType(N00VT); - // We know that the # elements of the results is the same as the - // # elements of the compare (and the # elements of the compare result - // for that matter). Check to see that they are the same size. If so, - // we know that the element size of the sext'd result matches the - // element size of the compare operands. - if (VT.getSizeInBits() == SVT.getSizeInBits()) - return DAG.getSetCC(DL, VT, N00, N01, CC); - - // If the desired elements are smaller or larger than the source - // elements, we can use a matching integer vector type and then - // truncate/sign extend. - EVT MatchingVecType = N00VT.changeVectorElementTypeToInteger(); - if (SVT == MatchingVecType) { - SDValue VsetCC = DAG.getSetCC(DL, MatchingVecType, N00, N01, CC); - return DAG.getSExtOrTrunc(VsetCC, DL, VT); + // If we already have the desired type, don't change it. + if (SVT != N0.getValueType()) { + // We know that the # elements of the results is the same as the + // # elements of the compare (and the # elements of the compare result + // for that matter). Check to see that they are the same size. If so, + // we know that the element size of the sext'd result matches the + // element size of the compare operands. + if (VT.getSizeInBits() == SVT.getSizeInBits()) + return DAG.getSetCC(DL, VT, N00, N01, CC); + + // If the desired elements are smaller or larger than the source + // elements, we can use a matching integer vector type and then + // truncate/sign extend. + EVT MatchingVecType = N00VT.changeVectorElementTypeToInteger(); + if (SVT == MatchingVecType) { + SDValue VsetCC = DAG.getSetCC(DL, MatchingVecType, N00, N01, CC); + return DAG.getSExtOrTrunc(VsetCC, DL, VT); + } } } @@ -8569,40 +8815,37 @@ static bool isTruncateOf(SelectionDAG &DAG, SDValue N, SDValue &Op, KnownBits &Known) { if (N->getOpcode() == ISD::TRUNCATE) { Op = N->getOperand(0); - DAG.computeKnownBits(Op, Known); + Known = DAG.computeKnownBits(Op); return true; } - if (N->getOpcode() != ISD::SETCC || N->getValueType(0) != MVT::i1 || - cast<CondCodeSDNode>(N->getOperand(2))->get() != ISD::SETNE) + if (N.getOpcode() != ISD::SETCC || + N.getValueType().getScalarType() != MVT::i1 || + cast<CondCodeSDNode>(N.getOperand(2))->get() != ISD::SETNE) return false; SDValue Op0 = N->getOperand(0); SDValue Op1 = N->getOperand(1); assert(Op0.getValueType() == Op1.getValueType()); - if (isNullConstant(Op0)) + if (isNullOrNullSplat(Op0)) Op = Op1; - else if (isNullConstant(Op1)) + else if (isNullOrNullSplat(Op1)) Op = Op0; else return false; - DAG.computeKnownBits(Op, Known); + Known = DAG.computeKnownBits(Op); - if (!(Known.Zero | 1).isAllOnesValue()) - return false; - - return true; + return (Known.Zero | 1).isAllOnesValue(); } SDValue DAGCombiner::visitZERO_EXTEND(SDNode *N) { SDValue N0 = N->getOperand(0); EVT VT = N->getValueType(0); - if (SDNode *Res = tryToFoldExtendOfConstant(N, TLI, DAG, LegalTypes, - LegalOperations)) - return SDValue(Res, 0); + if (SDValue Res = tryToFoldExtendOfConstant(N, TLI, DAG, LegalTypes)) + return Res; // fold (zext (zext x)) -> (zext x) // fold (zext (aext x)) -> (zext x) @@ -8613,17 +8856,16 @@ SDValue DAGCombiner::visitZERO_EXTEND(SDNode *N) { // fold (zext (truncate x)) -> (zext x) or // (zext (truncate x)) -> (truncate x) // This is valid when the truncated bits of x are already zero. - // FIXME: We should extend this to work for vectors too. SDValue Op; KnownBits Known; - if (!VT.isVector() && isTruncateOf(DAG, N0, Op, Known)) { + if (isTruncateOf(DAG, N0, Op, Known)) { APInt TruncatedBits = - (Op.getValueSizeInBits() == N0.getValueSizeInBits()) ? - APInt(Op.getValueSizeInBits(), 0) : - APInt::getBitsSet(Op.getValueSizeInBits(), - N0.getValueSizeInBits(), - std::min(Op.getValueSizeInBits(), - VT.getSizeInBits())); + (Op.getScalarValueSizeInBits() == N0.getScalarValueSizeInBits()) ? + APInt(Op.getScalarValueSizeInBits(), 0) : + APInt::getBitsSet(Op.getScalarValueSizeInBits(), + N0.getScalarValueSizeInBits(), + std::min(Op.getScalarValueSizeInBits(), + VT.getScalarSizeInBits())); if (TruncatedBits.isSubsetOf(Known.Zero)) return DAG.getZExtOrTrunc(Op, SDLoc(N), VT); } @@ -8851,9 +9093,8 @@ SDValue DAGCombiner::visitANY_EXTEND(SDNode *N) { SDValue N0 = N->getOperand(0); EVT VT = N->getValueType(0); - if (SDNode *Res = tryToFoldExtendOfConstant(N, TLI, DAG, LegalTypes, - LegalOperations)) - return SDValue(Res, 0); + if (SDValue Res = tryToFoldExtendOfConstant(N, TLI, DAG, LegalTypes)) + return Res; // fold (aext (aext x)) -> (aext x) // fold (aext (zext x)) -> (zext x) @@ -8968,17 +9209,16 @@ SDValue DAGCombiner::visitANY_EXTEND(SDNode *N) { return DAG.getSetCC(SDLoc(N), VT, N0.getOperand(0), N0.getOperand(1), cast<CondCodeSDNode>(N0.getOperand(2))->get()); + // If the desired elements are smaller or larger than the source // elements we can use a matching integer vector type and then // truncate/any extend - else { - EVT MatchingVectorType = N00VT.changeVectorElementTypeToInteger(); - SDValue VsetCC = - DAG.getSetCC(SDLoc(N), MatchingVectorType, N0.getOperand(0), - N0.getOperand(1), - cast<CondCodeSDNode>(N0.getOperand(2))->get()); - return DAG.getAnyExtOrTrunc(VsetCC, SDLoc(N), VT); - } + EVT MatchingVectorType = N00VT.changeVectorElementTypeToInteger(); + SDValue VsetCC = + DAG.getSetCC(SDLoc(N), MatchingVectorType, N0.getOperand(0), + N0.getOperand(1), + cast<CondCodeSDNode>(N0.getOperand(2))->get()); + return DAG.getAnyExtOrTrunc(VsetCC, SDLoc(N), VT); } // aext(setcc x,y,cc) -> select_cc x, y, 1, 0, cc @@ -9025,6 +9265,26 @@ SDValue DAGCombiner::visitAssertExt(SDNode *N) { return DAG.getNode(ISD::TRUNCATE, DL, N->getValueType(0), NewAssert); } + // If we have (AssertZext (truncate (AssertSext X, iX)), iY) and Y is smaller + // than X. Just move the AssertZext in front of the truncate and drop the + // AssertSExt. + if (N0.getOpcode() == ISD::TRUNCATE && N0.hasOneUse() && + N0.getOperand(0).getOpcode() == ISD::AssertSext && + Opcode == ISD::AssertZext) { + SDValue BigA = N0.getOperand(0); + EVT BigA_AssertVT = cast<VTSDNode>(BigA.getOperand(1))->getVT(); + assert(BigA_AssertVT.bitsLE(N0.getValueType()) && + "Asserting zero/sign-extended bits to a type larger than the " + "truncated destination does not provide information"); + + if (AssertVT.bitsLT(BigA_AssertVT)) { + SDLoc DL(N); + SDValue NewAssert = DAG.getNode(Opcode, DL, BigA.getValueType(), + BigA.getOperand(0), N1); + return DAG.getNode(ISD::TRUNCATE, DL, N->getValueType(0), NewAssert); + } + } + return SDValue(); } @@ -9046,6 +9306,8 @@ SDValue DAGCombiner::ReduceLoadWidth(SDNode *N) { if (VT.isVector()) return SDValue(); + unsigned ShAmt = 0; + bool HasShiftedOffset = false; // Special case: SIGN_EXTEND_INREG is basically truncating to ExtVT then // extended to VT. if (Opc == ISD::SIGN_EXTEND_INREG) { @@ -9073,15 +9335,25 @@ SDValue DAGCombiner::ReduceLoadWidth(SDNode *N) { } else if (Opc == ISD::AND) { // An AND with a constant mask is the same as a truncate + zero-extend. auto AndC = dyn_cast<ConstantSDNode>(N->getOperand(1)); - if (!AndC || !AndC->getAPIntValue().isMask()) + if (!AndC) + return SDValue(); + + const APInt &Mask = AndC->getAPIntValue(); + unsigned ActiveBits = 0; + if (Mask.isMask()) { + ActiveBits = Mask.countTrailingOnes(); + } else if (Mask.isShiftedMask()) { + ShAmt = Mask.countTrailingZeros(); + APInt ShiftedMask = Mask.lshr(ShAmt); + ActiveBits = ShiftedMask.countTrailingOnes(); + HasShiftedOffset = true; + } else return SDValue(); - unsigned ActiveBits = AndC->getAPIntValue().countTrailingOnes(); ExtType = ISD::ZEXTLOAD; ExtVT = EVT::getIntegerVT(*DAG.getContext(), ActiveBits); } - unsigned ShAmt = 0; if (N0.getOpcode() == ISD::SRL && N0.hasOneUse()) { SDValue SRL = N0; if (auto *ConstShift = dyn_cast<ConstantSDNode>(SRL.getOperand(1))) { @@ -9150,13 +9422,16 @@ SDValue DAGCombiner::ReduceLoadWidth(SDNode *N) { if (!isLegalNarrowLdSt(LN0, ExtType, ExtVT, ShAmt)) return SDValue(); - // For big endian targets, we need to adjust the offset to the pointer to - // load the correct bytes. - if (DAG.getDataLayout().isBigEndian()) { + auto AdjustBigEndianShift = [&](unsigned ShAmt) { unsigned LVTStoreBits = LN0->getMemoryVT().getStoreSizeInBits(); unsigned EVTStoreBits = ExtVT.getStoreSizeInBits(); - ShAmt = LVTStoreBits - EVTStoreBits - ShAmt; - } + return LVTStoreBits - EVTStoreBits - ShAmt; + }; + + // For big endian targets, we need to adjust the offset to the pointer to + // load the correct bytes. + if (DAG.getDataLayout().isBigEndian()) + ShAmt = AdjustBigEndianShift(ShAmt); EVT PtrType = N0.getOperand(1).getValueType(); uint64_t PtrOff = ShAmt / 8; @@ -9204,6 +9479,21 @@ SDValue DAGCombiner::ReduceLoadWidth(SDNode *N) { Result, DAG.getConstant(ShLeftAmt, DL, ShImmTy)); } + if (HasShiftedOffset) { + // Recalculate the shift amount after it has been altered to calculate + // the offset. + if (DAG.getDataLayout().isBigEndian()) + ShAmt = AdjustBigEndianShift(ShAmt); + + // We're using a shifted mask, so the load now has an offset. This means + // that data has been loaded into the lower bytes than it would have been + // before, so we need to shl the loaded data into the correct position in the + // register. + SDValue ShiftC = DAG.getConstant(ShAmt, DL, VT); + Result = DAG.getNode(ISD::SHL, DL, VT, Result, ShiftC); + DAG.ReplaceAllUsesOfValueWith(SDValue(N, 0), Result); + } + // Return the new loaded value. return Result; } @@ -9235,12 +9525,15 @@ SDValue DAGCombiner::visitSIGN_EXTEND_INREG(SDNode *N) { // fold (sext_in_reg (sext x)) -> (sext x) // fold (sext_in_reg (aext x)) -> (sext x) - // if x is small enough. + // if x is small enough or if we know that x has more than 1 sign bit and the + // sign_extend_inreg is extending from one of them. if (N0.getOpcode() == ISD::SIGN_EXTEND || N0.getOpcode() == ISD::ANY_EXTEND) { SDValue N00 = N0.getOperand(0); - if (N00.getScalarValueSizeInBits() <= EVTBits && + unsigned N00Bits = N00.getScalarValueSizeInBits(); + if ((N00Bits <= EVTBits || + (N00Bits - DAG.ComputeNumSignBits(N00)) < EVTBits) && (!LegalOperations || TLI.isOperationLegal(ISD::SIGN_EXTEND, VT))) - return DAG.getNode(ISD::SIGN_EXTEND, SDLoc(N), VT, N00, N1); + return DAG.getNode(ISD::SIGN_EXTEND, SDLoc(N), VT, N00); } // fold (sext_in_reg (*_extend_vector_inreg x)) -> (sext_vector_inreg x) @@ -9250,7 +9543,8 @@ SDValue DAGCombiner::visitSIGN_EXTEND_INREG(SDNode *N) { N0.getOperand(0).getScalarValueSizeInBits() == EVTBits) { if (!LegalOperations || TLI.isOperationLegal(ISD::SIGN_EXTEND_VECTOR_INREG, VT)) - return DAG.getSignExtendVectorInReg(N0.getOperand(0), SDLoc(N), VT); + return DAG.getNode(ISD::SIGN_EXTEND_VECTOR_INREG, SDLoc(N), VT, + N0.getOperand(0)); } // fold (sext_in_reg (zext x)) -> (sext x) @@ -9345,9 +9639,11 @@ SDValue DAGCombiner::visitSIGN_EXTEND_VECTOR_INREG(SDNode *N) { if (N0.isUndef()) return DAG.getUNDEF(VT); - if (SDNode *Res = tryToFoldExtendOfConstant(N, TLI, DAG, LegalTypes, - LegalOperations)) - return SDValue(Res, 0); + if (SDValue Res = tryToFoldExtendOfConstant(N, TLI, DAG, LegalTypes)) + return Res; + + if (SimplifyDemandedVectorElts(SDValue(N, 0))) + return SDValue(N, 0); return SDValue(); } @@ -9359,9 +9655,11 @@ SDValue DAGCombiner::visitZERO_EXTEND_VECTOR_INREG(SDNode *N) { if (N0.isUndef()) return DAG.getUNDEF(VT); - if (SDNode *Res = tryToFoldExtendOfConstant(N, TLI, DAG, LegalTypes, - LegalOperations)) - return SDValue(Res, 0); + if (SDValue Res = tryToFoldExtendOfConstant(N, TLI, DAG, LegalTypes)) + return Res; + + if (SimplifyDemandedVectorElts(SDValue(N, 0))) + return SDValue(N, 0); return SDValue(); } @@ -9458,8 +9756,7 @@ SDValue DAGCombiner::visitTRUNCATE(SDNode *N) { (!LegalOperations || TLI.isOperationLegalOrCustom(ISD::SHL, VT)) && TLI.isTypeDesirableForOp(ISD::SHL, VT)) { SDValue Amt = N0.getOperand(1); - KnownBits Known; - DAG.computeKnownBits(Amt, Known); + KnownBits Known = DAG.computeKnownBits(Amt); unsigned Size = VT.getScalarSizeInBits(); if (Known.getBitWidth() - Known.countMinLeadingZeros() <= Log2_32(Size)) { SDLoc SL(N); @@ -9636,6 +9933,32 @@ SDValue DAGCombiner::visitTRUNCATE(SDNode *N) { if (SDValue NewVSel = matchVSelectOpSizesWithSetCC(N)) return NewVSel; + // Narrow a suitable binary operation with a non-opaque constant operand by + // moving it ahead of the truncate. This is limited to pre-legalization + // because targets may prefer a wider type during later combines and invert + // this transform. + switch (N0.getOpcode()) { + case ISD::ADD: + case ISD::SUB: + case ISD::MUL: + case ISD::AND: + case ISD::OR: + case ISD::XOR: + if (!LegalOperations && N0.hasOneUse() && + (isConstantOrConstantVector(N0.getOperand(0), true) || + isConstantOrConstantVector(N0.getOperand(1), true))) { + // TODO: We already restricted this to pre-legalization, but for vectors + // we are extra cautious to not create an unsupported operation. + // Target-specific changes are likely needed to avoid regressions here. + if (VT.isScalarInteger() || TLI.isOperationLegal(N0.getOpcode(), VT)) { + SDLoc DL(N); + SDValue NarrowL = DAG.getNode(ISD::TRUNCATE, DL, VT, N0.getOperand(0)); + SDValue NarrowR = DAG.getNode(ISD::TRUNCATE, DL, VT, N0.getOperand(1)); + return DAG.getNode(N0.getOpcode(), DL, VT, NarrowL, NarrowR); + } + } + } + return SDValue(); } @@ -9694,11 +10017,11 @@ static SDValue foldBitcastedFPLogic(SDNode *N, SelectionDAG &DAG, if (!VT.isFloatingPoint() || !TLI.hasBitPreservingFPLogic(VT)) return SDValue(); - // TODO: Use splat values for the constant-checking below and remove this - // restriction. + // TODO: Handle cases where the integer constant is a different scalar + // bitwidth to the FP. SDValue N0 = N->getOperand(0); EVT SourceVT = N0.getValueType(); - if (SourceVT.isVector()) + if (VT.getScalarSizeInBits() != SourceVT.getScalarSizeInBits()) return SDValue(); unsigned FPOpcode; @@ -9706,25 +10029,35 @@ static SDValue foldBitcastedFPLogic(SDNode *N, SelectionDAG &DAG, switch (N0.getOpcode()) { case ISD::AND: FPOpcode = ISD::FABS; - SignMask = ~APInt::getSignMask(SourceVT.getSizeInBits()); + SignMask = ~APInt::getSignMask(SourceVT.getScalarSizeInBits()); break; case ISD::XOR: FPOpcode = ISD::FNEG; - SignMask = APInt::getSignMask(SourceVT.getSizeInBits()); + SignMask = APInt::getSignMask(SourceVT.getScalarSizeInBits()); + break; + case ISD::OR: + FPOpcode = ISD::FABS; + SignMask = APInt::getSignMask(SourceVT.getScalarSizeInBits()); break; - // TODO: ISD::OR --> ISD::FNABS? default: return SDValue(); } // Fold (bitcast int (and (bitcast fp X to int), 0x7fff...) to fp) -> fabs X // Fold (bitcast int (xor (bitcast fp X to int), 0x8000...) to fp) -> fneg X + // Fold (bitcast int (or (bitcast fp X to int), 0x8000...) to fp) -> + // fneg (fabs X) SDValue LogicOp0 = N0.getOperand(0); - ConstantSDNode *LogicOp1 = dyn_cast<ConstantSDNode>(N0.getOperand(1)); + ConstantSDNode *LogicOp1 = isConstOrConstSplat(N0.getOperand(1), true); if (LogicOp1 && LogicOp1->getAPIntValue() == SignMask && LogicOp0.getOpcode() == ISD::BITCAST && - LogicOp0->getOperand(0).getValueType() == VT) - return DAG.getNode(FPOpcode, SDLoc(N), VT, LogicOp0->getOperand(0)); + LogicOp0.getOperand(0).getValueType() == VT) { + SDValue FPOp = DAG.getNode(FPOpcode, SDLoc(N), VT, LogicOp0.getOperand(0)); + NumFPLogicOpsConv++; + if (N0.getOpcode() == ISD::OR) + return DAG.getNode(ISD::FNEG, SDLoc(N), VT, FPOp); + return FPOp; + } return SDValue(); } @@ -9737,33 +10070,32 @@ SDValue DAGCombiner::visitBITCAST(SDNode *N) { return DAG.getUNDEF(VT); // If the input is a BUILD_VECTOR with all constant elements, fold this now. - // Only do this before legalize, since afterward the target may be depending - // on the bitconvert. + // Only do this before legalize types, since we might create an illegal + // scalar type. Even if we knew we wouldn't create an illegal scalar type + // we can only do this before legalize ops, since the target maybe + // depending on the bitcast. // First check to see if this is all constant. if (!LegalTypes && N0.getOpcode() == ISD::BUILD_VECTOR && N0.getNode()->hasOneUse() && - VT.isVector()) { - bool isSimple = cast<BuildVectorSDNode>(N0)->isConstant(); - - EVT DestEltVT = N->getValueType(0).getVectorElementType(); - assert(!DestEltVT.isVector() && - "Element type of vector ValueType must not be vector!"); - if (isSimple) - return ConstantFoldBITCASTofBUILD_VECTOR(N0.getNode(), DestEltVT); - } + VT.isVector() && cast<BuildVectorSDNode>(N0)->isConstant()) + return ConstantFoldBITCASTofBUILD_VECTOR(N0.getNode(), + VT.getVectorElementType()); // If the input is a constant, let getNode fold it. - // We always need to check that this is just a fp -> int or int -> conversion - // otherwise we will get back N which will confuse the caller into thinking - // we used CombineTo. This can block target combines from running. If we can't - // allowed legal operations, we need to ensure the resulting operation will be - // legal. - // TODO: Maybe we should check that the return value isn't N explicitly? - if ((isa<ConstantSDNode>(N0) && VT.isFloatingPoint() && !VT.isVector() && - (!LegalOperations || TLI.isOperationLegal(ISD::ConstantFP, VT))) || - (isa<ConstantFPSDNode>(N0) && VT.isInteger() && !VT.isVector() && - (!LegalOperations || TLI.isOperationLegal(ISD::Constant, VT)))) - return DAG.getBitcast(VT, N0); + if (isa<ConstantSDNode>(N0) || isa<ConstantFPSDNode>(N0)) { + // If we can't allow illegal operations, we need to check that this is just + // a fp -> int or int -> conversion and that the resulting operation will + // be legal. + if (!LegalOperations || + (isa<ConstantSDNode>(N0) && VT.isFloatingPoint() && !VT.isVector() && + TLI.isOperationLegal(ISD::ConstantFP, VT)) || + (isa<ConstantFPSDNode>(N0) && VT.isInteger() && !VT.isVector() && + TLI.isOperationLegal(ISD::Constant, VT))) { + SDValue C = DAG.getBitcast(VT, N0); + if (C.getNode() != N) + return C; + } + } // (conv (conv x, t1), t2) -> (conv x, t2) if (N0.getOpcode() == ISD::BITCAST) @@ -9772,12 +10104,16 @@ SDValue DAGCombiner::visitBITCAST(SDNode *N) { // fold (conv (load x)) -> (load (conv*)x) // If the resultant load doesn't need a higher alignment than the original! if (ISD::isNormalLoad(N0.getNode()) && N0.hasOneUse() && - // Do not change the width of a volatile load. - !cast<LoadSDNode>(N0)->isVolatile() && // Do not remove the cast if the types differ in endian layout. TLI.hasBigEndianPartOrdering(N0.getValueType(), DAG.getDataLayout()) == TLI.hasBigEndianPartOrdering(VT, DAG.getDataLayout()) && - (!LegalOperations || TLI.isOperationLegal(ISD::LOAD, VT)) && + // If the load is volatile, we only want to change the load type if the + // resulting load is legal. Otherwise we might increase the number of + // memory accesses. We don't care if the original type was legal or not + // as we assume software couldn't rely on the number of accesses of an + // illegal type. + ((!LegalOperations && !cast<LoadSDNode>(N0)->isVolatile()) || + TLI.isOperationLegal(ISD::LOAD, VT)) && TLI.isLoadBitCastBeneficial(N0.getValueType(), VT)) { LoadSDNode *LN0 = cast<LoadSDNode>(N0); unsigned OrigAlign = LN0->getAlignment(); @@ -9934,7 +10270,7 @@ SDValue DAGCombiner::visitBITCAST(SDNode *N) { // float vectors bitcast to integer vectors) into shuffles. // bitcast(shuffle(bitcast(s0),bitcast(s1))) -> shuffle(s0,s1) if (Level < AfterLegalizeDAG && TLI.isTypeLegal(VT) && VT.isVector() && - N0->getOpcode() == ISD::VECTOR_SHUFFLE && + N0->getOpcode() == ISD::VECTOR_SHUFFLE && N0.hasOneUse() && VT.getVectorNumElements() >= N0.getValueType().getVectorNumElements() && !(VT.getVectorNumElements() % N0.getValueType().getVectorNumElements())) { ShuffleVectorSDNode *SVN = cast<ShuffleVectorSDNode>(N0); @@ -10000,15 +10336,6 @@ ConstantFoldBITCASTofBUILD_VECTOR(SDNode *BV, EVT DstEltVT) { // If this is a conversion of N elements of one type to N elements of another // type, convert each element. This handles FP<->INT cases. if (SrcBitSize == DstBitSize) { - EVT VT = EVT::getVectorVT(*DAG.getContext(), DstEltVT, - BV->getValueType(0).getVectorNumElements()); - - // Due to the FP element handling below calling this routine recursively, - // we can end up with a scalar-to-vector node here. - if (BV->getOpcode() == ISD::SCALAR_TO_VECTOR) - return DAG.getNode(ISD::SCALAR_TO_VECTOR, SDLoc(BV), VT, - DAG.getBitcast(DstEltVT, BV->getOperand(0))); - SmallVector<SDValue, 8> Ops; for (SDValue Op : BV->op_values()) { // If the vector element type is not legal, the BUILD_VECTOR operands @@ -10018,6 +10345,8 @@ ConstantFoldBITCASTofBUILD_VECTOR(SDNode *BV, EVT DstEltVT) { Ops.push_back(DAG.getBitcast(DstEltVT, Op)); AddToWorklist(Ops.back().getNode()); } + EVT VT = EVT::getVectorVT(*DAG.getContext(), DstEltVT, + BV->getValueType(0).getVectorNumElements()); return DAG.getBuildVector(VT, SDLoc(BV), Ops); } @@ -10651,17 +10980,18 @@ SDValue DAGCombiner::visitFMULForFMADistributiveCombine(SDNode *N) { unsigned PreferredFusedOpcode = HasFMAD ? ISD::FMAD : ISD::FMA; bool Aggressive = TLI.enableAggressiveFMAFusion(VT); - // fold (fmul (fadd x, +1.0), y) -> (fma x, y, y) - // fold (fmul (fadd x, -1.0), y) -> (fma x, y, (fneg y)) + // fold (fmul (fadd x0, +1.0), y) -> (fma x0, y, y) + // fold (fmul (fadd x0, -1.0), y) -> (fma x0, y, (fneg y)) auto FuseFADD = [&](SDValue X, SDValue Y, const SDNodeFlags Flags) { if (X.getOpcode() == ISD::FADD && (Aggressive || X->hasOneUse())) { - auto XC1 = isConstOrConstSplatFP(X.getOperand(1)); - if (XC1 && XC1->isExactlyValue(+1.0)) - return DAG.getNode(PreferredFusedOpcode, SL, VT, X.getOperand(0), Y, - Y, Flags); - if (XC1 && XC1->isExactlyValue(-1.0)) - return DAG.getNode(PreferredFusedOpcode, SL, VT, X.getOperand(0), Y, - DAG.getNode(ISD::FNEG, SL, VT, Y), Flags); + if (auto *C = isConstOrConstSplatFP(X.getOperand(1), true)) { + if (C->isExactlyValue(+1.0)) + return DAG.getNode(PreferredFusedOpcode, SL, VT, X.getOperand(0), Y, + Y, Flags); + if (C->isExactlyValue(-1.0)) + return DAG.getNode(PreferredFusedOpcode, SL, VT, X.getOperand(0), Y, + DAG.getNode(ISD::FNEG, SL, VT, Y), Flags); + } } return SDValue(); }; @@ -10671,29 +11001,30 @@ SDValue DAGCombiner::visitFMULForFMADistributiveCombine(SDNode *N) { if (SDValue FMA = FuseFADD(N1, N0, Flags)) return FMA; - // fold (fmul (fsub +1.0, x), y) -> (fma (fneg x), y, y) - // fold (fmul (fsub -1.0, x), y) -> (fma (fneg x), y, (fneg y)) - // fold (fmul (fsub x, +1.0), y) -> (fma x, y, (fneg y)) - // fold (fmul (fsub x, -1.0), y) -> (fma x, y, y) + // fold (fmul (fsub +1.0, x1), y) -> (fma (fneg x1), y, y) + // fold (fmul (fsub -1.0, x1), y) -> (fma (fneg x1), y, (fneg y)) + // fold (fmul (fsub x0, +1.0), y) -> (fma x0, y, (fneg y)) + // fold (fmul (fsub x0, -1.0), y) -> (fma x0, y, y) auto FuseFSUB = [&](SDValue X, SDValue Y, const SDNodeFlags Flags) { if (X.getOpcode() == ISD::FSUB && (Aggressive || X->hasOneUse())) { - auto XC0 = isConstOrConstSplatFP(X.getOperand(0)); - if (XC0 && XC0->isExactlyValue(+1.0)) - return DAG.getNode(PreferredFusedOpcode, SL, VT, - DAG.getNode(ISD::FNEG, SL, VT, X.getOperand(1)), Y, - Y, Flags); - if (XC0 && XC0->isExactlyValue(-1.0)) - return DAG.getNode(PreferredFusedOpcode, SL, VT, - DAG.getNode(ISD::FNEG, SL, VT, X.getOperand(1)), Y, - DAG.getNode(ISD::FNEG, SL, VT, Y), Flags); - - auto XC1 = isConstOrConstSplatFP(X.getOperand(1)); - if (XC1 && XC1->isExactlyValue(+1.0)) - return DAG.getNode(PreferredFusedOpcode, SL, VT, X.getOperand(0), Y, - DAG.getNode(ISD::FNEG, SL, VT, Y), Flags); - if (XC1 && XC1->isExactlyValue(-1.0)) - return DAG.getNode(PreferredFusedOpcode, SL, VT, X.getOperand(0), Y, - Y, Flags); + if (auto *C0 = isConstOrConstSplatFP(X.getOperand(0), true)) { + if (C0->isExactlyValue(+1.0)) + return DAG.getNode(PreferredFusedOpcode, SL, VT, + DAG.getNode(ISD::FNEG, SL, VT, X.getOperand(1)), Y, + Y, Flags); + if (C0->isExactlyValue(-1.0)) + return DAG.getNode(PreferredFusedOpcode, SL, VT, + DAG.getNode(ISD::FNEG, SL, VT, X.getOperand(1)), Y, + DAG.getNode(ISD::FNEG, SL, VT, Y), Flags); + } + if (auto *C1 = isConstOrConstSplatFP(X.getOperand(1), true)) { + if (C1->isExactlyValue(+1.0)) + return DAG.getNode(PreferredFusedOpcode, SL, VT, X.getOperand(0), Y, + DAG.getNode(ISD::FNEG, SL, VT, Y), Flags); + if (C1->isExactlyValue(-1.0)) + return DAG.getNode(PreferredFusedOpcode, SL, VT, X.getOperand(0), Y, + Y, Flags); + } } return SDValue(); }; @@ -10706,14 +11037,6 @@ SDValue DAGCombiner::visitFMULForFMADistributiveCombine(SDNode *N) { return SDValue(); } -static bool isFMulNegTwo(SDValue &N) { - if (N.getOpcode() != ISD::FMUL) - return false; - if (ConstantFPSDNode *CFP = isConstOrConstSplatFP(N.getOperand(1))) - return CFP->isExactlyValue(-2.0); - return false; -} - SDValue DAGCombiner::visitFADD(SDNode *N) { SDValue N0 = N->getOperand(0); SDValue N1 = N->getOperand(1); @@ -10737,6 +11060,12 @@ SDValue DAGCombiner::visitFADD(SDNode *N) { if (N0CFP && !N1CFP) return DAG.getNode(ISD::FADD, DL, VT, N1, N0, Flags); + // N0 + -0.0 --> N0 (also allowed with +0.0 and fast-math) + ConstantFPSDNode *N1C = isConstOrConstSplatFP(N1, true); + if (N1C && N1C->isZero()) + if (N1C->isNegative() || Options.UnsafeFPMath || Flags.hasNoSignedZeros()) + return N0; + if (SDValue NewSel = foldBinOpIntoSelect(N)) return NewSel; @@ -10752,23 +11081,24 @@ SDValue DAGCombiner::visitFADD(SDNode *N) { return DAG.getNode(ISD::FSUB, DL, VT, N1, GetNegatedExpression(N0, DAG, LegalOperations), Flags); - // fold (fadd A, (fmul B, -2.0)) -> (fsub A, (fadd B, B)) - // fold (fadd (fmul B, -2.0), A) -> (fsub A, (fadd B, B)) - if ((isFMulNegTwo(N0) && N0.hasOneUse()) || - (isFMulNegTwo(N1) && N1.hasOneUse())) { - bool N1IsFMul = isFMulNegTwo(N1); - SDValue AddOp = N1IsFMul ? N1.getOperand(0) : N0.getOperand(0); - SDValue Add = DAG.getNode(ISD::FADD, DL, VT, AddOp, AddOp, Flags); - return DAG.getNode(ISD::FSUB, DL, VT, N1IsFMul ? N0 : N1, Add, Flags); - } + auto isFMulNegTwo = [](SDValue FMul) { + if (!FMul.hasOneUse() || FMul.getOpcode() != ISD::FMUL) + return false; + auto *C = isConstOrConstSplatFP(FMul.getOperand(1), true); + return C && C->isExactlyValue(-2.0); + }; - ConstantFPSDNode *N1C = isConstOrConstSplatFP(N1); - if (N1C && N1C->isZero()) { - if (N1C->isNegative() || Options.UnsafeFPMath || - Flags.hasNoSignedZeros()) { - // fold (fadd A, 0) -> A - return N0; - } + // fadd (fmul B, -2.0), A --> fsub A, (fadd B, B) + if (isFMulNegTwo(N0)) { + SDValue B = N0.getOperand(0); + SDValue Add = DAG.getNode(ISD::FADD, DL, VT, B, B, Flags); + return DAG.getNode(ISD::FSUB, DL, VT, N1, Add, Flags); + } + // fadd A, (fmul B, -2.0) --> fsub A, (fadd B, B) + if (isFMulNegTwo(N1)) { + SDValue B = N1.getOperand(0); + SDValue Add = DAG.getNode(ISD::FADD, DL, VT, B, B, Flags); + return DAG.getNode(ISD::FSUB, DL, VT, N0, Add, Flags); } // No FP constant should be created after legalization as Instruction @@ -10887,8 +11217,8 @@ SDValue DAGCombiner::visitFADD(SDNode *N) { SDValue DAGCombiner::visitFSUB(SDNode *N) { SDValue N0 = N->getOperand(0); SDValue N1 = N->getOperand(1); - ConstantFPSDNode *N0CFP = isConstOrConstSplatFP(N0); - ConstantFPSDNode *N1CFP = isConstOrConstSplatFP(N1); + ConstantFPSDNode *N0CFP = isConstOrConstSplatFP(N0, true); + ConstantFPSDNode *N1CFP = isConstOrConstSplatFP(N1, true); EVT VT = N->getValueType(0); SDLoc DL(N); const TargetOptions &Options = DAG.getTarget().Options; @@ -10920,9 +11250,10 @@ SDValue DAGCombiner::visitFSUB(SDNode *N) { return DAG.getConstantFP(0.0f, DL, VT); } - // (fsub 0, B) -> -B + // (fsub -0.0, N1) -> -N1 if (N0CFP && N0CFP->isZero()) { - if (Options.NoSignedZerosFPMath || Flags.hasNoSignedZeros()) { + if (N0CFP->isNegative() || + (Options.NoSignedZerosFPMath || Flags.hasNoSignedZeros())) { if (isNegatibleForFree(N1, LegalOperations, TLI, &Options)) return GetNegatedExpression(N1, DAG, LegalOperations); if (!LegalOperations || TLI.isOperationLegal(ISD::FNEG, VT)) @@ -10930,27 +11261,22 @@ SDValue DAGCombiner::visitFSUB(SDNode *N) { } } + if ((Options.UnsafeFPMath || + (Flags.hasAllowReassociation() && Flags.hasNoSignedZeros())) + && N1.getOpcode() == ISD::FADD) { + // X - (X + Y) -> -Y + if (N0 == N1->getOperand(0)) + return DAG.getNode(ISD::FNEG, DL, VT, N1->getOperand(1), Flags); + // X - (Y + X) -> -Y + if (N0 == N1->getOperand(1)) + return DAG.getNode(ISD::FNEG, DL, VT, N1->getOperand(0), Flags); + } + // fold (fsub A, (fneg B)) -> (fadd A, B) if (isNegatibleForFree(N1, LegalOperations, TLI, &Options)) return DAG.getNode(ISD::FADD, DL, VT, N0, GetNegatedExpression(N1, DAG, LegalOperations), Flags); - // If 'unsafe math' is enabled, fold lots of things. - if (Options.UnsafeFPMath) { - // (fsub x, (fadd x, y)) -> (fneg y) - // (fsub x, (fadd y, x)) -> (fneg y) - if (N1.getOpcode() == ISD::FADD) { - SDValue N10 = N1->getOperand(0); - SDValue N11 = N1->getOperand(1); - - if (N10 == N0 && isNegatibleForFree(N11, LegalOperations, TLI, &Options)) - return GetNegatedExpression(N11, DAG, LegalOperations); - - if (N11 == N0 && isNegatibleForFree(N10, LegalOperations, TLI, &Options)) - return GetNegatedExpression(N10, DAG, LegalOperations); - } - } - // FSUB -> FMA combines: if (SDValue Fused = visitFSUBForFMACombine(N)) { AddToWorklist(Fused.getNode()); @@ -10963,8 +11289,8 @@ SDValue DAGCombiner::visitFSUB(SDNode *N) { SDValue DAGCombiner::visitFMUL(SDNode *N) { SDValue N0 = N->getOperand(0); SDValue N1 = N->getOperand(1); - ConstantFPSDNode *N0CFP = isConstOrConstSplatFP(N0); - ConstantFPSDNode *N1CFP = isConstOrConstSplatFP(N1); + ConstantFPSDNode *N0CFP = isConstOrConstSplatFP(N0, true); + ConstantFPSDNode *N1CFP = isConstOrConstSplatFP(N1, true); EVT VT = N->getValueType(0); SDLoc DL(N); const TargetOptions &Options = DAG.getTarget().Options; @@ -11002,26 +11328,16 @@ SDValue DAGCombiner::visitFMUL(SDNode *N) { if (Options.UnsafeFPMath || Flags.hasAllowReassociation()) { // fmul (fmul X, C1), C2 -> fmul X, C1 * C2 - if (N0.getOpcode() == ISD::FMUL) { - // Fold scalars or any vector constants (not just splats). - // This fold is done in general by InstCombine, but extra fmul insts - // may have been generated during lowering. + if (isConstantFPBuildVectorOrConstantFP(N1) && + N0.getOpcode() == ISD::FMUL) { SDValue N00 = N0.getOperand(0); SDValue N01 = N0.getOperand(1); - auto *BV1 = dyn_cast<BuildVectorSDNode>(N1); - auto *BV00 = dyn_cast<BuildVectorSDNode>(N00); - auto *BV01 = dyn_cast<BuildVectorSDNode>(N01); - - // Check 1: Make sure that the first operand of the inner multiply is NOT - // a constant. Otherwise, we may induce infinite looping. - if (!(isConstOrConstSplatFP(N00) || (BV00 && BV00->isConstant()))) { - // Check 2: Make sure that the second operand of the inner multiply and - // the second operand of the outer multiply are constants. - if ((N1CFP && isConstOrConstSplatFP(N01)) || - (BV1 && BV01 && BV1->isConstant() && BV01->isConstant())) { - SDValue MulConsts = DAG.getNode(ISD::FMUL, DL, VT, N01, N1, Flags); - return DAG.getNode(ISD::FMUL, DL, VT, N00, MulConsts, Flags); - } + // Avoid an infinite loop by making sure that N00 is not a constant + // (the inner multiply has not been constant folded yet). + if (isConstantFPBuildVectorOrConstantFP(N01) && + !isConstantFPBuildVectorOrConstantFP(N00)) { + SDValue MulConsts = DAG.getNode(ISD::FMUL, DL, VT, N01, N1, Flags); + return DAG.getNode(ISD::FMUL, DL, VT, N00, MulConsts, Flags); } } @@ -11445,15 +11761,15 @@ static inline bool CanCombineFCOPYSIGN_EXTEND_ROUND(SDNode *N) { SDValue DAGCombiner::visitFCOPYSIGN(SDNode *N) { SDValue N0 = N->getOperand(0); SDValue N1 = N->getOperand(1); - ConstantFPSDNode *N0CFP = dyn_cast<ConstantFPSDNode>(N0); - ConstantFPSDNode *N1CFP = dyn_cast<ConstantFPSDNode>(N1); + bool N0CFP = isConstantFPBuildVectorOrConstantFP(N0); + bool N1CFP = isConstantFPBuildVectorOrConstantFP(N1); EVT VT = N->getValueType(0); if (N0CFP && N1CFP) // Constant fold return DAG.getNode(ISD::FCOPYSIGN, SDLoc(N), VT, N0, N1); - if (N1CFP) { - const APFloat &V = N1CFP->getValueAPF(); + if (ConstantFPSDNode *N1C = isConstOrConstSplatFP(N->getOperand(1))) { + const APFloat &V = N1C->getValueAPF(); // copysign(x, c1) -> fabs(x) iff ispos(c1) // copysign(x, c1) -> fneg(fabs(x)) iff isneg(c1) if (!V.isNegative()) { @@ -11489,6 +11805,72 @@ SDValue DAGCombiner::visitFCOPYSIGN(SDNode *N) { return SDValue(); } +SDValue DAGCombiner::visitFPOW(SDNode *N) { + ConstantFPSDNode *ExponentC = isConstOrConstSplatFP(N->getOperand(1)); + if (!ExponentC) + return SDValue(); + + // Try to convert x ** (1/3) into cube root. + // TODO: Handle the various flavors of long double. + // TODO: Since we're approximating, we don't need an exact 1/3 exponent. + // Some range near 1/3 should be fine. + EVT VT = N->getValueType(0); + if ((VT == MVT::f32 && ExponentC->getValueAPF().isExactlyValue(1.0f/3.0f)) || + (VT == MVT::f64 && ExponentC->getValueAPF().isExactlyValue(1.0/3.0))) { + // pow(-0.0, 1/3) = +0.0; cbrt(-0.0) = -0.0. + // pow(-inf, 1/3) = +inf; cbrt(-inf) = -inf. + // pow(-val, 1/3) = nan; cbrt(-val) = -num. + // For regular numbers, rounding may cause the results to differ. + // Therefore, we require { nsz ninf nnan afn } for this transform. + // TODO: We could select out the special cases if we don't have nsz/ninf. + SDNodeFlags Flags = N->getFlags(); + if (!Flags.hasNoSignedZeros() || !Flags.hasNoInfs() || !Flags.hasNoNaNs() || + !Flags.hasApproximateFuncs()) + return SDValue(); + + // Do not create a cbrt() libcall if the target does not have it, and do not + // turn a pow that has lowering support into a cbrt() libcall. + if (!DAG.getLibInfo().has(LibFunc_cbrt) || + (!DAG.getTargetLoweringInfo().isOperationExpand(ISD::FPOW, VT) && + DAG.getTargetLoweringInfo().isOperationExpand(ISD::FCBRT, VT))) + return SDValue(); + + return DAG.getNode(ISD::FCBRT, SDLoc(N), VT, N->getOperand(0), Flags); + } + + // Try to convert x ** (1/4) into square roots. + // x ** (1/2) is canonicalized to sqrt, so we do not bother with that case. + // TODO: This could be extended (using a target hook) to handle smaller + // power-of-2 fractional exponents. + if (ExponentC->getValueAPF().isExactlyValue(0.25)) { + // pow(-0.0, 0.25) = +0.0; sqrt(sqrt(-0.0)) = -0.0. + // pow(-inf, 0.25) = +inf; sqrt(sqrt(-inf)) = NaN. + // For regular numbers, rounding may cause the results to differ. + // Therefore, we require { nsz ninf afn } for this transform. + // TODO: We could select out the special cases if we don't have nsz/ninf. + SDNodeFlags Flags = N->getFlags(); + if (!Flags.hasNoSignedZeros() || !Flags.hasNoInfs() || + !Flags.hasApproximateFuncs()) + return SDValue(); + + // Don't double the number of libcalls. We are trying to inline fast code. + if (!DAG.getTargetLoweringInfo().isOperationLegalOrCustom(ISD::FSQRT, VT)) + return SDValue(); + + // Assume that libcalls are the smallest code. + // TODO: This restriction should probably be lifted for vectors. + if (DAG.getMachineFunction().getFunction().optForSize()) + return SDValue(); + + // pow(X, 0.25) --> sqrt(sqrt(X)) + SDLoc DL(N); + SDValue Sqrt = DAG.getNode(ISD::FSQRT, DL, VT, N->getOperand(0), Flags); + return DAG.getNode(ISD::FSQRT, DL, VT, Sqrt, Flags); + } + + return SDValue(); +} + static SDValue foldFPToIntToFP(SDNode *N, SelectionDAG &DAG, const TargetLowering &TLI) { // This optimization is guarded by a function attribute because it may produce @@ -11538,8 +11920,8 @@ SDValue DAGCombiner::visitSINT_TO_FP(SDNode *N) { // If the input is a legal type, and SINT_TO_FP is not legal on this target, // but UINT_TO_FP is legal on this target, try to convert. - if (!TLI.isOperationLegalOrCustom(ISD::SINT_TO_FP, OpVT) && - TLI.isOperationLegalOrCustom(ISD::UINT_TO_FP, OpVT)) { + if (!hasOperation(ISD::SINT_TO_FP, OpVT) && + hasOperation(ISD::UINT_TO_FP, OpVT)) { // If the sign bit is known to be zero, we can change this to UINT_TO_FP. if (DAG.SignBitIsZero(N0)) return DAG.getNode(ISD::UINT_TO_FP, SDLoc(N), VT, N0); @@ -11595,8 +11977,8 @@ SDValue DAGCombiner::visitUINT_TO_FP(SDNode *N) { // If the input is a legal type, and UINT_TO_FP is not legal on this target, // but SINT_TO_FP is legal on this target, try to convert. - if (!TLI.isOperationLegalOrCustom(ISD::UINT_TO_FP, OpVT) && - TLI.isOperationLegalOrCustom(ISD::SINT_TO_FP, OpVT)) { + if (!hasOperation(ISD::UINT_TO_FP, OpVT) && + hasOperation(ISD::SINT_TO_FP, OpVT)) { // If the sign bit is known to be zero, we can change this to SINT_TO_FP. if (DAG.SignBitIsZero(N0)) return DAG.getNode(ISD::SINT_TO_FP, SDLoc(N), VT, N0); @@ -11917,7 +12299,8 @@ SDValue DAGCombiner::visitFNEG(SDNode *N) { return SDValue(); } -SDValue DAGCombiner::visitFMINNUM(SDNode *N) { +static SDValue visitFMinMax(SelectionDAG &DAG, SDNode *N, + APFloat (*Op)(const APFloat &, const APFloat &)) { SDValue N0 = N->getOperand(0); SDValue N1 = N->getOperand(1); EVT VT = N->getValueType(0); @@ -11927,36 +12310,31 @@ SDValue DAGCombiner::visitFMINNUM(SDNode *N) { if (N0CFP && N1CFP) { const APFloat &C0 = N0CFP->getValueAPF(); const APFloat &C1 = N1CFP->getValueAPF(); - return DAG.getConstantFP(minnum(C0, C1), SDLoc(N), VT); + return DAG.getConstantFP(Op(C0, C1), SDLoc(N), VT); } // Canonicalize to constant on RHS. if (isConstantFPBuildVectorOrConstantFP(N0) && - !isConstantFPBuildVectorOrConstantFP(N1)) - return DAG.getNode(ISD::FMINNUM, SDLoc(N), VT, N1, N0); + !isConstantFPBuildVectorOrConstantFP(N1)) + return DAG.getNode(N->getOpcode(), SDLoc(N), VT, N1, N0); return SDValue(); } -SDValue DAGCombiner::visitFMAXNUM(SDNode *N) { - SDValue N0 = N->getOperand(0); - SDValue N1 = N->getOperand(1); - EVT VT = N->getValueType(0); - const ConstantFPSDNode *N0CFP = isConstOrConstSplatFP(N0); - const ConstantFPSDNode *N1CFP = isConstOrConstSplatFP(N1); +SDValue DAGCombiner::visitFMINNUM(SDNode *N) { + return visitFMinMax(DAG, N, minnum); +} - if (N0CFP && N1CFP) { - const APFloat &C0 = N0CFP->getValueAPF(); - const APFloat &C1 = N1CFP->getValueAPF(); - return DAG.getConstantFP(maxnum(C0, C1), SDLoc(N), VT); - } +SDValue DAGCombiner::visitFMAXNUM(SDNode *N) { + return visitFMinMax(DAG, N, maxnum); +} - // Canonicalize to constant on RHS. - if (isConstantFPBuildVectorOrConstantFP(N0) && - !isConstantFPBuildVectorOrConstantFP(N1)) - return DAG.getNode(ISD::FMAXNUM, SDLoc(N), VT, N1, N0); +SDValue DAGCombiner::visitFMINIMUM(SDNode *N) { + return visitFMinMax(DAG, N, minimum); +} - return SDValue(); +SDValue DAGCombiner::visitFMAXIMUM(SDNode *N) { + return visitFMinMax(DAG, N, maximum); } SDValue DAGCombiner::visitFABS(SDNode *N) { @@ -11976,11 +12354,8 @@ SDValue DAGCombiner::visitFABS(SDNode *N) { if (N0.getOpcode() == ISD::FNEG || N0.getOpcode() == ISD::FCOPYSIGN) return DAG.getNode(ISD::FABS, SDLoc(N), VT, N0.getOperand(0)); - // Transform fabs(bitconvert(x)) -> bitconvert(x & ~sign) to avoid loading - // constant pool values. - if (!TLI.isFAbsFree(VT) && - N0.getOpcode() == ISD::BITCAST && - N0.getNode()->hasOneUse()) { + // fabs(bitcast(x)) -> bitcast(x & ~sign) to avoid constant pool loads. + if (!TLI.isFAbsFree(VT) && N0.getOpcode() == ISD::BITCAST && N0.hasOneUse()) { SDValue Int = N0.getOperand(0); EVT IntVT = Int.getValueType(); if (IntVT.isInteger() && !IntVT.isVector()) { @@ -12512,8 +12887,15 @@ bool DAGCombiner::CombineToPostIndexedLoadStore(SDNode *N) { if (TryNext) continue; - // Check for #2 - if (!Op->isPredecessorOf(N) && !N->isPredecessorOf(Op)) { + // Check for #2. + SmallPtrSet<const SDNode *, 32> Visited; + SmallVector<const SDNode *, 8> Worklist; + // Ptr is predecessor to both N and Op. + Visited.insert(Ptr.getNode()); + Worklist.push_back(N); + Worklist.push_back(Op); + if (!SDNode::hasPredecessorHelper(N, Visited, Worklist) && + !SDNode::hasPredecessorHelper(Op, Visited, Worklist)) { SDValue Result = isLoad ? DAG.getIndexedLoad(SDValue(N,0), SDLoc(N), BasePtr, Offset, AM) @@ -12571,6 +12953,157 @@ SDValue DAGCombiner::SplitIndexingFromLoad(LoadSDNode *LD) { return DAG.getNode(Opc, SDLoc(LD), BP.getSimpleValueType(), BP, Inc); } +static inline int numVectorEltsOrZero(EVT T) { + return T.isVector() ? T.getVectorNumElements() : 0; +} + +bool DAGCombiner::getTruncatedStoreValue(StoreSDNode *ST, SDValue &Val) { + Val = ST->getValue(); + EVT STType = Val.getValueType(); + EVT STMemType = ST->getMemoryVT(); + if (STType == STMemType) + return true; + if (isTypeLegal(STMemType)) + return false; // fail. + if (STType.isFloatingPoint() && STMemType.isFloatingPoint() && + TLI.isOperationLegal(ISD::FTRUNC, STMemType)) { + Val = DAG.getNode(ISD::FTRUNC, SDLoc(ST), STMemType, Val); + return true; + } + if (numVectorEltsOrZero(STType) == numVectorEltsOrZero(STMemType) && + STType.isInteger() && STMemType.isInteger()) { + Val = DAG.getNode(ISD::TRUNCATE, SDLoc(ST), STMemType, Val); + return true; + } + if (STType.getSizeInBits() == STMemType.getSizeInBits()) { + Val = DAG.getBitcast(STMemType, Val); + return true; + } + return false; // fail. +} + +bool DAGCombiner::extendLoadedValueToExtension(LoadSDNode *LD, SDValue &Val) { + EVT LDMemType = LD->getMemoryVT(); + EVT LDType = LD->getValueType(0); + assert(Val.getValueType() == LDMemType && + "Attempting to extend value of non-matching type"); + if (LDType == LDMemType) + return true; + if (LDMemType.isInteger() && LDType.isInteger()) { + switch (LD->getExtensionType()) { + case ISD::NON_EXTLOAD: + Val = DAG.getBitcast(LDType, Val); + return true; + case ISD::EXTLOAD: + Val = DAG.getNode(ISD::ANY_EXTEND, SDLoc(LD), LDType, Val); + return true; + case ISD::SEXTLOAD: + Val = DAG.getNode(ISD::SIGN_EXTEND, SDLoc(LD), LDType, Val); + return true; + case ISD::ZEXTLOAD: + Val = DAG.getNode(ISD::ZERO_EXTEND, SDLoc(LD), LDType, Val); + return true; + } + } + return false; +} + +SDValue DAGCombiner::ForwardStoreValueToDirectLoad(LoadSDNode *LD) { + if (OptLevel == CodeGenOpt::None || LD->isVolatile()) + return SDValue(); + SDValue Chain = LD->getOperand(0); + StoreSDNode *ST = dyn_cast<StoreSDNode>(Chain.getNode()); + if (!ST || ST->isVolatile()) + return SDValue(); + + EVT LDType = LD->getValueType(0); + EVT LDMemType = LD->getMemoryVT(); + EVT STMemType = ST->getMemoryVT(); + EVT STType = ST->getValue().getValueType(); + + BaseIndexOffset BasePtrLD = BaseIndexOffset::match(LD, DAG); + BaseIndexOffset BasePtrST = BaseIndexOffset::match(ST, DAG); + int64_t Offset; + if (!BasePtrST.equalBaseIndex(BasePtrLD, DAG, Offset)) + return SDValue(); + + // Normalize for Endianness. After this Offset=0 will denote that the least + // significant bit in the loaded value maps to the least significant bit in + // the stored value). With Offset=n (for n > 0) the loaded value starts at the + // n:th least significant byte of the stored value. + if (DAG.getDataLayout().isBigEndian()) + Offset = (STMemType.getStoreSizeInBits() - + LDMemType.getStoreSizeInBits()) / 8 - Offset; + + // Check that the stored value cover all bits that are loaded. + bool STCoversLD = + (Offset >= 0) && + (Offset * 8 + LDMemType.getSizeInBits() <= STMemType.getSizeInBits()); + + auto ReplaceLd = [&](LoadSDNode *LD, SDValue Val, SDValue Chain) -> SDValue { + if (LD->isIndexed()) { + bool IsSub = (LD->getAddressingMode() == ISD::PRE_DEC || + LD->getAddressingMode() == ISD::POST_DEC); + unsigned Opc = IsSub ? ISD::SUB : ISD::ADD; + SDValue Idx = DAG.getNode(Opc, SDLoc(LD), LD->getOperand(1).getValueType(), + LD->getOperand(1), LD->getOperand(2)); + SDValue Ops[] = {Val, Idx, Chain}; + return CombineTo(LD, Ops, 3); + } + return CombineTo(LD, Val, Chain); + }; + + if (!STCoversLD) + return SDValue(); + + // Memory as copy space (potentially masked). + if (Offset == 0 && LDType == STType && STMemType == LDMemType) { + // Simple case: Direct non-truncating forwarding + if (LDType.getSizeInBits() == LDMemType.getSizeInBits()) + return ReplaceLd(LD, ST->getValue(), Chain); + // Can we model the truncate and extension with an and mask? + if (STType.isInteger() && LDMemType.isInteger() && !STType.isVector() && + !LDMemType.isVector() && LD->getExtensionType() != ISD::SEXTLOAD) { + // Mask to size of LDMemType + auto Mask = + DAG.getConstant(APInt::getLowBitsSet(STType.getSizeInBits(), + STMemType.getSizeInBits()), + SDLoc(ST), STType); + auto Val = DAG.getNode(ISD::AND, SDLoc(LD), LDType, ST->getValue(), Mask); + return ReplaceLd(LD, Val, Chain); + } + } + + // TODO: Deal with nonzero offset. + if (LD->getBasePtr().isUndef() || Offset != 0) + return SDValue(); + // Model necessary truncations / extenstions. + SDValue Val; + // Truncate Value To Stored Memory Size. + do { + if (!getTruncatedStoreValue(ST, Val)) + continue; + if (!isTypeLegal(LDMemType)) + continue; + if (STMemType != LDMemType) { + // TODO: Support vectors? This requires extract_subvector/bitcast. + if (!STMemType.isVector() && !LDMemType.isVector() && + STMemType.isInteger() && LDMemType.isInteger()) + Val = DAG.getNode(ISD::TRUNCATE, SDLoc(LD), LDMemType, Val); + else + continue; + } + if (!extendLoadedValueToExtension(LD, Val)) + continue; + return ReplaceLd(LD, Val, Chain); + } while (false); + + // On failure, cleanup dead nodes we may have created. + if (Val->use_empty()) + deleteAndRecombine(Val.getNode()); + return SDValue(); +} + SDValue DAGCombiner::visitLOAD(SDNode *N) { LoadSDNode *LD = cast<LoadSDNode>(N); SDValue Chain = LD->getChain(); @@ -12637,17 +13170,8 @@ SDValue DAGCombiner::visitLOAD(SDNode *N) { // If this load is directly stored, replace the load value with the stored // value. - // TODO: Handle store large -> read small portion. - // TODO: Handle TRUNCSTORE/LOADEXT - if (OptLevel != CodeGenOpt::None && - ISD::isNormalLoad(N) && !LD->isVolatile()) { - if (ISD::isNON_TRUNCStore(Chain.getNode())) { - StoreSDNode *PrevST = cast<StoreSDNode>(Chain); - if (PrevST->getBasePtr() == Ptr && - PrevST->getValue().getValueType() == N->getValueType(0)) - return CombineTo(N, PrevST->getOperand(1), Chain); - } - } + if (auto V = ForwardStoreValueToDirectLoad(LD)) + return V; // Try to infer better alignment information than the load already has. if (OptLevel != CodeGenOpt::None && LD->isUnindexed()) { @@ -13055,8 +13579,7 @@ static void adjustCostForPairing(SmallVectorImpl<LoadedSlice> &LoadedSlices, // Sort the slices so that elements that are likely to be next to each // other in memory are next to each other in the list. - llvm::sort(LoadedSlices.begin(), LoadedSlices.end(), - [](const LoadedSlice &LHS, const LoadedSlice &RHS) { + llvm::sort(LoadedSlices, [](const LoadedSlice &LHS, const LoadedSlice &RHS) { assert(LHS.Origin == RHS.Origin && "Different bases not implemented."); return LHS.getOffsetFromBase() < RHS.getOffsetFromBase(); }); @@ -13689,7 +14212,7 @@ bool DAGCombiner::MergeStoresOfConstantsOrVecElts( SDValue Val = St->getValue(); // If constant is of the wrong type, convert it now. if (MemVT != Val.getValueType()) { - Val = peekThroughBitcast(Val); + Val = peekThroughBitcasts(Val); // Deal with constants of wrong size. if (ElementSizeBits != Val.getValueSizeInBits()) { EVT IntMemVT = @@ -13715,7 +14238,7 @@ bool DAGCombiner::MergeStoresOfConstantsOrVecElts( SmallVector<SDValue, 8> Ops; for (unsigned i = 0; i < NumStores; ++i) { StoreSDNode *St = cast<StoreSDNode>(StoreNodes[i].MemNode); - SDValue Val = peekThroughBitcast(St->getValue()); + SDValue Val = peekThroughBitcasts(St->getValue()); // All operands of BUILD_VECTOR / CONCAT_VECTOR must be of // type MemVT. If the underlying value is not the correct // type, but it is an extraction of an appropriate vector we @@ -13725,19 +14248,17 @@ bool DAGCombiner::MergeStoresOfConstantsOrVecElts( if ((MemVT != Val.getValueType()) && (Val.getOpcode() == ISD::EXTRACT_VECTOR_ELT || Val.getOpcode() == ISD::EXTRACT_SUBVECTOR)) { - SDValue Vec = Val.getOperand(0); EVT MemVTScalarTy = MemVT.getScalarType(); // We may need to add a bitcast here to get types to line up. - if (MemVTScalarTy != Vec.getValueType()) { - unsigned Elts = Vec.getValueType().getSizeInBits() / - MemVTScalarTy.getSizeInBits(); - EVT NewVecTy = - EVT::getVectorVT(*DAG.getContext(), MemVTScalarTy, Elts); - Vec = DAG.getBitcast(NewVecTy, Vec); + if (MemVTScalarTy != Val.getValueType().getScalarType()) { + Val = DAG.getBitcast(MemVT, Val); + } else { + unsigned OpC = MemVT.isVector() ? ISD::EXTRACT_SUBVECTOR + : ISD::EXTRACT_VECTOR_ELT; + SDValue Vec = Val.getOperand(0); + SDValue Idx = Val.getOperand(1); + Val = DAG.getNode(OpC, SDLoc(Val), MemVT, Vec, Idx); } - auto OpC = (MemVT.isVector()) ? ISD::EXTRACT_SUBVECTOR - : ISD::EXTRACT_VECTOR_ELT; - Val = DAG.getNode(OpC, SDLoc(Val), MemVT, Vec, Val.getOperand(1)); } Ops.push_back(Val); } @@ -13762,7 +14283,7 @@ bool DAGCombiner::MergeStoresOfConstantsOrVecElts( StoreSDNode *St = cast<StoreSDNode>(StoreNodes[Idx].MemNode); SDValue Val = St->getValue(); - Val = peekThroughBitcast(Val); + Val = peekThroughBitcasts(Val); StoreInt <<= ElementSizeBits; if (ConstantSDNode *C = dyn_cast<ConstantSDNode>(Val)) { StoreInt |= C->getAPIntValue() @@ -13825,7 +14346,7 @@ void DAGCombiner::getStoreMergeCandidates( BaseIndexOffset BasePtr = BaseIndexOffset::match(St, DAG); EVT MemVT = St->getMemoryVT(); - SDValue Val = peekThroughBitcast(St->getValue()); + SDValue Val = peekThroughBitcasts(St->getValue()); // We must have a base and an offset. if (!BasePtr.getBase().getNode()) return; @@ -13859,7 +14380,7 @@ void DAGCombiner::getStoreMergeCandidates( int64_t &Offset) -> bool { if (Other->isVolatile() || Other->isIndexed()) return false; - SDValue Val = peekThroughBitcast(Other->getValue()); + SDValue Val = peekThroughBitcasts(Other->getValue()); // Allow merging constants of different types as integers. bool NoTypeMatch = (MemVT.isInteger()) ? !MemVT.bitsEq(Other->getMemoryVT()) : Other->getMemoryVT() != MemVT; @@ -13966,11 +14487,12 @@ bool DAGCombiner::checkMergeStoreCandidatesForDependencies( Worklist.push_back(RootNode); while (!Worklist.empty()) { auto N = Worklist.pop_back_val(); + if (!Visited.insert(N).second) + continue; // Already present in Visited. if (N->getOpcode() == ISD::TokenFactor) { for (SDValue Op : N->ops()) Worklist.push_back(Op.getNode()); } - Visited.insert(N); } // Don't count pruning nodes towards max. @@ -13983,14 +14505,14 @@ bool DAGCombiner::checkMergeStoreCandidatesForDependencies( // in candidate selection and can be // safely ignored // * Value (Op 1) -> Cycles may happen (e.g. through load chains) - // * Address (Op 2) -> Merged addresses may only vary by a fixed constant - // and so no cycles are possible. - // * (Op 3) -> appears to always be undef. Cannot be source of cycle. - // - // Thus we need only check predecessors of the value operands. - auto *Op = N->getOperand(1).getNode(); - if (Visited.insert(Op).second) - Worklist.push_back(Op); + // * Address (Op 2) -> Merged addresses may only vary by a fixed constant, + // but aren't necessarily fromt the same base node, so + // cycles possible (e.g. via indexed store). + // * (Op 3) -> Represents the pre or post-indexing offset (or undef for + // non-indexed stores). Not constant on all targets (e.g. ARM) + // and so can participate in a cycle. + for (unsigned j = 1; j < N->getNumOperands(); ++j) + Worklist.push_back(N->getOperand(j).getNode()); } // Search through DAG. We can stop early if we find a store node. for (unsigned i = 0; i < NumStores; ++i) @@ -14023,7 +14545,7 @@ bool DAGCombiner::MergeConsecutiveStores(StoreSDNode *St) { // Perform an early exit check. Do not bother looking at stored values that // are not constants, loads, or extracted vector elements. - SDValue StoredVal = peekThroughBitcast(St->getValue()); + SDValue StoredVal = peekThroughBitcasts(St->getValue()); bool IsLoadSrc = isa<LoadSDNode>(StoredVal); bool IsConstantSrc = isa<ConstantSDNode>(StoredVal) || isa<ConstantFPSDNode>(StoredVal); @@ -14044,10 +14566,9 @@ bool DAGCombiner::MergeConsecutiveStores(StoreSDNode *St) { // Sort the memory operands according to their distance from the // base pointer. - llvm::sort(StoreNodes.begin(), StoreNodes.end(), - [](MemOpLink LHS, MemOpLink RHS) { - return LHS.OffsetFromBase < RHS.OffsetFromBase; - }); + llvm::sort(StoreNodes, [](MemOpLink LHS, MemOpLink RHS) { + return LHS.OffsetFromBase < RHS.OffsetFromBase; + }); // Store Merge attempts to merge the lowest stores. This generally // works out as if successful, as the remaining stores are checked @@ -14292,7 +14813,7 @@ bool DAGCombiner::MergeConsecutiveStores(StoreSDNode *St) { for (unsigned i = 0; i < NumConsecutiveStores; ++i) { StoreSDNode *St = cast<StoreSDNode>(StoreNodes[i].MemNode); - SDValue Val = peekThroughBitcast(St->getValue()); + SDValue Val = peekThroughBitcasts(St->getValue()); LoadSDNode *Ld = cast<LoadSDNode>(Val); BaseIndexOffset LdPtr = BaseIndexOffset::match(Ld, DAG); @@ -14640,8 +15161,13 @@ SDValue DAGCombiner::visitSTORE(SDNode *N) { if (Value.getOpcode() == ISD::BITCAST && !ST->isTruncatingStore() && ST->isUnindexed()) { EVT SVT = Value.getOperand(0).getValueType(); + // If the store is volatile, we only want to change the store type if the + // resulting store is legal. Otherwise we might increase the number of + // memory accesses. We don't care if the original type was legal or not + // as we assume software couldn't rely on the number of accesses of an + // illegal type. if (((!LegalOperations && !ST->isVolatile()) || - TLI.isOperationLegalOrCustom(ISD::STORE, SVT)) && + TLI.isOperationLegal(ISD::STORE, SVT)) && TLI.isStoreBitCastBeneficial(Value.getValueType(), SVT)) { unsigned OrigAlign = ST->getAlignment(); bool Fast = false; @@ -14692,7 +15218,9 @@ SDValue DAGCombiner::visitSTORE(SDNode *N) { // FIXME: is there such a thing as a truncating indexed store? if (ST->isTruncatingStore() && ST->isUnindexed() && - Value.getValueType().isInteger()) { + Value.getValueType().isInteger() && + (!isa<ConstantSDNode>(Value) || + !cast<ConstantSDNode>(Value)->isOpaque())) { // See if we can simplify the input to this truncstore with knowledge that // only the low bits are being used. For example: // "truncstore (or (shl x, 8), y), i8" -> "truncstore y, i8" @@ -14976,6 +15504,7 @@ SDValue DAGCombiner::visitINSERT_VECTOR_ELT(SDNode *N) { return InVec; EVT VT = InVec.getValueType(); + unsigned NumElts = VT.getVectorNumElements(); // Remove redundant insertions: // (insert_vector_elt x (extract_vector_elt x idx) idx) -> x @@ -14983,12 +15512,19 @@ SDValue DAGCombiner::visitINSERT_VECTOR_ELT(SDNode *N) { InVec == InVal.getOperand(0) && EltNo == InVal.getOperand(1)) return InVec; - // We must know which element is being inserted for folds below here. auto *IndexC = dyn_cast<ConstantSDNode>(EltNo); - if (!IndexC) + if (!IndexC) { + // If this is variable insert to undef vector, it might be better to splat: + // inselt undef, InVal, EltNo --> build_vector < InVal, InVal, ... > + if (InVec.isUndef() && TLI.shouldSplatInsEltVarIndex(VT)) { + SmallVector<SDValue, 8> Ops(NumElts, InVal); + return DAG.getBuildVector(VT, DL, Ops); + } return SDValue(); - unsigned Elt = IndexC->getZExtValue(); + } + // We must know which element is being inserted for folds below here. + unsigned Elt = IndexC->getZExtValue(); if (SDValue Shuf = combineInsertEltToShuffle(N, Elt)) return Shuf; @@ -15026,11 +15562,11 @@ SDValue DAGCombiner::visitINSERT_VECTOR_ELT(SDNode *N) { Ops.append(InVec.getNode()->op_begin(), InVec.getNode()->op_end()); } else if (InVec.isUndef()) { - unsigned NElts = VT.getVectorNumElements(); - Ops.append(NElts, DAG.getUNDEF(InVal.getValueType())); + Ops.append(NumElts, DAG.getUNDEF(InVal.getValueType())); } else { return SDValue(); } + assert(Ops.size() == NumElts && "Unexpected vector size"); // Insert the element if (Elt < Ops.size()) { @@ -15044,8 +15580,9 @@ SDValue DAGCombiner::visitINSERT_VECTOR_ELT(SDNode *N) { return DAG.getBuildVector(VT, DL, Ops); } -SDValue DAGCombiner::ReplaceExtractVectorEltOfLoadWithNarrowedLoad( - SDNode *EVE, EVT InVecVT, SDValue EltNo, LoadSDNode *OriginalLoad) { +SDValue DAGCombiner::scalarizeExtractedVectorLoad(SDNode *EVE, EVT InVecVT, + SDValue EltNo, + LoadSDNode *OriginalLoad) { assert(!OriginalLoad->isVolatile()); EVT ResultVT = EVE->getValueType(0); @@ -15127,70 +15664,132 @@ SDValue DAGCombiner::ReplaceExtractVectorEltOfLoadWithNarrowedLoad( return SDValue(EVE, 0); } +/// Transform a vector binary operation into a scalar binary operation by moving +/// the math/logic after an extract element of a vector. +static SDValue scalarizeExtractedBinop(SDNode *ExtElt, SelectionDAG &DAG, + bool LegalOperations) { + SDValue Vec = ExtElt->getOperand(0); + SDValue Index = ExtElt->getOperand(1); + auto *IndexC = dyn_cast<ConstantSDNode>(Index); + if (!IndexC || !ISD::isBinaryOp(Vec.getNode()) || !Vec.hasOneUse()) + return SDValue(); + + // Targets may want to avoid this to prevent an expensive register transfer. + const TargetLowering &TLI = DAG.getTargetLoweringInfo(); + if (!TLI.shouldScalarizeBinop(Vec)) + return SDValue(); + + // Extracting an element of a vector constant is constant-folded, so this + // transform is just replacing a vector op with a scalar op while moving the + // extract. + SDValue Op0 = Vec.getOperand(0); + SDValue Op1 = Vec.getOperand(1); + if (isAnyConstantBuildVector(Op0, true) || + isAnyConstantBuildVector(Op1, true)) { + // extractelt (binop X, C), IndexC --> binop (extractelt X, IndexC), C' + // extractelt (binop C, X), IndexC --> binop C', (extractelt X, IndexC) + SDLoc DL(ExtElt); + EVT VT = ExtElt->getValueType(0); + SDValue Ext0 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, VT, Op0, Index); + SDValue Ext1 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, VT, Op1, Index); + return DAG.getNode(Vec.getOpcode(), DL, VT, Ext0, Ext1); + } + + return SDValue(); +} + SDValue DAGCombiner::visitEXTRACT_VECTOR_ELT(SDNode *N) { - // (vextract (scalar_to_vector val, 0) -> val - SDValue InVec = N->getOperand(0); - EVT VT = InVec.getValueType(); - EVT NVT = N->getValueType(0); + SDValue VecOp = N->getOperand(0); + SDValue Index = N->getOperand(1); + EVT ScalarVT = N->getValueType(0); + EVT VecVT = VecOp.getValueType(); + if (VecOp.isUndef()) + return DAG.getUNDEF(ScalarVT); - if (InVec.isUndef()) - return DAG.getUNDEF(NVT); + // extract_vector_elt (insert_vector_elt vec, val, idx), idx) -> val + // + // This only really matters if the index is non-constant since other combines + // on the constant elements already work. + SDLoc DL(N); + if (VecOp.getOpcode() == ISD::INSERT_VECTOR_ELT && + Index == VecOp.getOperand(2)) { + SDValue Elt = VecOp.getOperand(1); + return VecVT.isInteger() ? DAG.getAnyExtOrTrunc(Elt, DL, ScalarVT) : Elt; + } - if (InVec.getOpcode() == ISD::SCALAR_TO_VECTOR) { + // (vextract (scalar_to_vector val, 0) -> val + if (VecOp.getOpcode() == ISD::SCALAR_TO_VECTOR) { // Check if the result type doesn't match the inserted element type. A // SCALAR_TO_VECTOR may truncate the inserted element and the // EXTRACT_VECTOR_ELT may widen the extracted vector. - SDValue InOp = InVec.getOperand(0); - if (InOp.getValueType() != NVT) { - assert(InOp.getValueType().isInteger() && NVT.isInteger()); - return DAG.getSExtOrTrunc(InOp, SDLoc(InVec), NVT); + SDValue InOp = VecOp.getOperand(0); + if (InOp.getValueType() != ScalarVT) { + assert(InOp.getValueType().isInteger() && ScalarVT.isInteger()); + return DAG.getSExtOrTrunc(InOp, DL, ScalarVT); } return InOp; } - SDValue EltNo = N->getOperand(1); - ConstantSDNode *ConstEltNo = dyn_cast<ConstantSDNode>(EltNo); - // extract_vector_elt of out-of-bounds element -> UNDEF - if (ConstEltNo && ConstEltNo->getAPIntValue().uge(VT.getVectorNumElements())) - return DAG.getUNDEF(NVT); + auto *IndexC = dyn_cast<ConstantSDNode>(Index); + unsigned NumElts = VecVT.getVectorNumElements(); + if (IndexC && IndexC->getAPIntValue().uge(NumElts)) + return DAG.getUNDEF(ScalarVT); // extract_vector_elt (build_vector x, y), 1 -> y - if (ConstEltNo && - InVec.getOpcode() == ISD::BUILD_VECTOR && - TLI.isTypeLegal(VT) && - (InVec.hasOneUse() || - TLI.aggressivelyPreferBuildVectorSources(VT))) { - SDValue Elt = InVec.getOperand(ConstEltNo->getZExtValue()); + if (IndexC && VecOp.getOpcode() == ISD::BUILD_VECTOR && + TLI.isTypeLegal(VecVT) && + (VecOp.hasOneUse() || TLI.aggressivelyPreferBuildVectorSources(VecVT))) { + SDValue Elt = VecOp.getOperand(IndexC->getZExtValue()); EVT InEltVT = Elt.getValueType(); // Sometimes build_vector's scalar input types do not match result type. - if (NVT == InEltVT) + if (ScalarVT == InEltVT) return Elt; // TODO: It may be useful to truncate if free if the build_vector implicitly // converts. } - // extract_vector_elt (v2i32 (bitcast i64:x)), EltTrunc -> i32 (trunc i64:x) - bool isLE = DAG.getDataLayout().isLittleEndian(); - unsigned EltTrunc = isLE ? 0 : VT.getVectorNumElements() - 1; - if (ConstEltNo && InVec.getOpcode() == ISD::BITCAST && InVec.hasOneUse() && - ConstEltNo->getZExtValue() == EltTrunc && VT.isInteger()) { - SDValue BCSrc = InVec.getOperand(0); - if (BCSrc.getValueType().isScalarInteger()) - return DAG.getNode(ISD::TRUNCATE, SDLoc(N), NVT, BCSrc); + // TODO: These transforms should not require the 'hasOneUse' restriction, but + // there are regressions on multiple targets without it. We can end up with a + // mess of scalar and vector code if we reduce only part of the DAG to scalar. + if (IndexC && VecOp.getOpcode() == ISD::BITCAST && VecVT.isInteger() && + VecOp.hasOneUse()) { + // The vector index of the LSBs of the source depend on the endian-ness. + bool IsLE = DAG.getDataLayout().isLittleEndian(); + unsigned ExtractIndex = IndexC->getZExtValue(); + // extract_elt (v2i32 (bitcast i64:x)), BCTruncElt -> i32 (trunc i64:x) + unsigned BCTruncElt = IsLE ? 0 : NumElts - 1; + SDValue BCSrc = VecOp.getOperand(0); + if (ExtractIndex == BCTruncElt && BCSrc.getValueType().isScalarInteger()) + return DAG.getNode(ISD::TRUNCATE, DL, ScalarVT, BCSrc); + + if (LegalTypes && BCSrc.getValueType().isInteger() && + BCSrc.getOpcode() == ISD::SCALAR_TO_VECTOR) { + // ext_elt (bitcast (scalar_to_vec i64 X to v2i64) to v4i32), TruncElt --> + // trunc i64 X to i32 + SDValue X = BCSrc.getOperand(0); + assert(X.getValueType().isScalarInteger() && ScalarVT.isScalarInteger() && + "Extract element and scalar to vector can't change element type " + "from FP to integer."); + unsigned XBitWidth = X.getValueSizeInBits(); + unsigned VecEltBitWidth = VecVT.getScalarSizeInBits(); + BCTruncElt = IsLE ? 0 : XBitWidth / VecEltBitWidth - 1; + + // An extract element return value type can be wider than its vector + // operand element type. In that case, the high bits are undefined, so + // it's possible that we may need to extend rather than truncate. + if (ExtractIndex == BCTruncElt && XBitWidth > VecEltBitWidth) { + assert(XBitWidth % VecEltBitWidth == 0 && + "Scalar bitwidth must be a multiple of vector element bitwidth"); + return DAG.getAnyExtOrTrunc(X, DL, ScalarVT); + } + } } - // extract_vector_elt (insert_vector_elt vec, val, idx), idx) -> val - // - // This only really matters if the index is non-constant since other combines - // on the constant elements already work. - if (InVec.getOpcode() == ISD::INSERT_VECTOR_ELT && - EltNo == InVec.getOperand(2)) { - SDValue Elt = InVec.getOperand(1); - return VT.isInteger() ? DAG.getAnyExtOrTrunc(Elt, SDLoc(N), NVT) : Elt; - } + if (SDValue BO = scalarizeExtractedBinop(N, DAG, LegalOperations)) + return BO; // Transform: (EXTRACT_VECTOR_ELT( VECTOR_SHUFFLE )) -> EXTRACT_VECTOR_ELT. // We only perform this optimization before the op legalization phase because @@ -15198,30 +15797,29 @@ SDValue DAGCombiner::visitEXTRACT_VECTOR_ELT(SDNode *N) { // patterns. For example on AVX, extracting elements from a wide vector // without using extract_subvector. However, if we can find an underlying // scalar value, then we can always use that. - if (ConstEltNo && InVec.getOpcode() == ISD::VECTOR_SHUFFLE) { - int NumElem = VT.getVectorNumElements(); - ShuffleVectorSDNode *SVOp = cast<ShuffleVectorSDNode>(InVec); + if (IndexC && VecOp.getOpcode() == ISD::VECTOR_SHUFFLE) { + auto *Shuf = cast<ShuffleVectorSDNode>(VecOp); // Find the new index to extract from. - int OrigElt = SVOp->getMaskElt(ConstEltNo->getZExtValue()); + int OrigElt = Shuf->getMaskElt(IndexC->getZExtValue()); // Extracting an undef index is undef. if (OrigElt == -1) - return DAG.getUNDEF(NVT); + return DAG.getUNDEF(ScalarVT); // Select the right vector half to extract from. SDValue SVInVec; - if (OrigElt < NumElem) { - SVInVec = InVec->getOperand(0); + if (OrigElt < (int)NumElts) { + SVInVec = VecOp.getOperand(0); } else { - SVInVec = InVec->getOperand(1); - OrigElt -= NumElem; + SVInVec = VecOp.getOperand(1); + OrigElt -= NumElts; } if (SVInVec.getOpcode() == ISD::BUILD_VECTOR) { SDValue InOp = SVInVec.getOperand(OrigElt); - if (InOp.getValueType() != NVT) { - assert(InOp.getValueType().isInteger() && NVT.isInteger()); - InOp = DAG.getSExtOrTrunc(InOp, SDLoc(SVInVec), NVT); + if (InOp.getValueType() != ScalarVT) { + assert(InOp.getValueType().isInteger() && ScalarVT.isInteger()); + InOp = DAG.getSExtOrTrunc(InOp, DL, ScalarVT); } return InOp; @@ -15232,136 +15830,131 @@ SDValue DAGCombiner::visitEXTRACT_VECTOR_ELT(SDNode *N) { if (!LegalOperations || // FIXME: Should really be just isOperationLegalOrCustom. - TLI.isOperationLegal(ISD::EXTRACT_VECTOR_ELT, VT) || - TLI.isOperationExpand(ISD::VECTOR_SHUFFLE, VT)) { + TLI.isOperationLegal(ISD::EXTRACT_VECTOR_ELT, VecVT) || + TLI.isOperationExpand(ISD::VECTOR_SHUFFLE, VecVT)) { EVT IndexTy = TLI.getVectorIdxTy(DAG.getDataLayout()); - return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, SDLoc(N), NVT, SVInVec, - DAG.getConstant(OrigElt, SDLoc(SVOp), IndexTy)); + return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, ScalarVT, SVInVec, + DAG.getConstant(OrigElt, DL, IndexTy)); } } // If only EXTRACT_VECTOR_ELT nodes use the source vector we can // simplify it based on the (valid) extraction indices. - if (llvm::all_of(InVec->uses(), [&](SDNode *Use) { + if (llvm::all_of(VecOp->uses(), [&](SDNode *Use) { return Use->getOpcode() == ISD::EXTRACT_VECTOR_ELT && - Use->getOperand(0) == InVec && + Use->getOperand(0) == VecOp && isa<ConstantSDNode>(Use->getOperand(1)); })) { - APInt DemandedElts = APInt::getNullValue(VT.getVectorNumElements()); - for (SDNode *Use : InVec->uses()) { + APInt DemandedElts = APInt::getNullValue(NumElts); + for (SDNode *Use : VecOp->uses()) { auto *CstElt = cast<ConstantSDNode>(Use->getOperand(1)); - if (CstElt->getAPIntValue().ult(VT.getVectorNumElements())) + if (CstElt->getAPIntValue().ult(NumElts)) DemandedElts.setBit(CstElt->getZExtValue()); } - if (SimplifyDemandedVectorElts(InVec, DemandedElts, true)) + if (SimplifyDemandedVectorElts(VecOp, DemandedElts, true)) { + // We simplified the vector operand of this extract element. If this + // extract is not dead, visit it again so it is folded properly. + if (N->getOpcode() != ISD::DELETED_NODE) + AddToWorklist(N); return SDValue(N, 0); + } } - bool BCNumEltsChanged = false; - EVT ExtVT = VT.getVectorElementType(); - EVT LVT = ExtVT; - + // Everything under here is trying to match an extract of a loaded value. // If the result of load has to be truncated, then it's not necessarily // profitable. - if (NVT.bitsLT(LVT) && !TLI.isTruncateFree(LVT, NVT)) + bool BCNumEltsChanged = false; + EVT ExtVT = VecVT.getVectorElementType(); + EVT LVT = ExtVT; + if (ScalarVT.bitsLT(LVT) && !TLI.isTruncateFree(LVT, ScalarVT)) return SDValue(); - if (InVec.getOpcode() == ISD::BITCAST) { + if (VecOp.getOpcode() == ISD::BITCAST) { // Don't duplicate a load with other uses. - if (!InVec.hasOneUse()) + if (!VecOp.hasOneUse()) return SDValue(); - EVT BCVT = InVec.getOperand(0).getValueType(); + EVT BCVT = VecOp.getOperand(0).getValueType(); if (!BCVT.isVector() || ExtVT.bitsGT(BCVT.getVectorElementType())) return SDValue(); - if (VT.getVectorNumElements() != BCVT.getVectorNumElements()) + if (NumElts != BCVT.getVectorNumElements()) BCNumEltsChanged = true; - InVec = InVec.getOperand(0); + VecOp = VecOp.getOperand(0); ExtVT = BCVT.getVectorElementType(); } - // (vextract (vN[if]M load $addr), i) -> ([if]M load $addr + i * size) - if (!LegalOperations && !ConstEltNo && InVec.hasOneUse() && - ISD::isNormalLoad(InVec.getNode()) && - !N->getOperand(1)->hasPredecessor(InVec.getNode())) { - SDValue Index = N->getOperand(1); - if (LoadSDNode *OrigLoad = dyn_cast<LoadSDNode>(InVec)) { - if (!OrigLoad->isVolatile()) { - return ReplaceExtractVectorEltOfLoadWithNarrowedLoad(N, VT, Index, - OrigLoad); - } - } + // extract (vector load $addr), i --> load $addr + i * size + if (!LegalOperations && !IndexC && VecOp.hasOneUse() && + ISD::isNormalLoad(VecOp.getNode()) && + !Index->hasPredecessor(VecOp.getNode())) { + auto *VecLoad = dyn_cast<LoadSDNode>(VecOp); + if (VecLoad && !VecLoad->isVolatile()) + return scalarizeExtractedVectorLoad(N, VecVT, Index, VecLoad); } // Perform only after legalization to ensure build_vector / vector_shuffle // optimizations have already been done. - if (!LegalOperations) return SDValue(); + if (!LegalOperations || !IndexC) + return SDValue(); // (vextract (v4f32 load $addr), c) -> (f32 load $addr+c*size) // (vextract (v4f32 s2v (f32 load $addr)), c) -> (f32 load $addr+c*size) // (vextract (v4f32 shuffle (load $addr), <1,u,u,u>), 0) -> (f32 load $addr) + int Elt = IndexC->getZExtValue(); + LoadSDNode *LN0 = nullptr; + if (ISD::isNormalLoad(VecOp.getNode())) { + LN0 = cast<LoadSDNode>(VecOp); + } else if (VecOp.getOpcode() == ISD::SCALAR_TO_VECTOR && + VecOp.getOperand(0).getValueType() == ExtVT && + ISD::isNormalLoad(VecOp.getOperand(0).getNode())) { + // Don't duplicate a load with other uses. + if (!VecOp.hasOneUse()) + return SDValue(); - if (ConstEltNo) { - int Elt = cast<ConstantSDNode>(EltNo)->getZExtValue(); + LN0 = cast<LoadSDNode>(VecOp.getOperand(0)); + } + if (auto *Shuf = dyn_cast<ShuffleVectorSDNode>(VecOp)) { + // (vextract (vector_shuffle (load $addr), v2, <1, u, u, u>), 1) + // => + // (load $addr+1*size) - LoadSDNode *LN0 = nullptr; - const ShuffleVectorSDNode *SVN = nullptr; - if (ISD::isNormalLoad(InVec.getNode())) { - LN0 = cast<LoadSDNode>(InVec); - } else if (InVec.getOpcode() == ISD::SCALAR_TO_VECTOR && - InVec.getOperand(0).getValueType() == ExtVT && - ISD::isNormalLoad(InVec.getOperand(0).getNode())) { - // Don't duplicate a load with other uses. - if (!InVec.hasOneUse()) - return SDValue(); + // Don't duplicate a load with other uses. + if (!VecOp.hasOneUse()) + return SDValue(); + + // If the bit convert changed the number of elements, it is unsafe + // to examine the mask. + if (BCNumEltsChanged) + return SDValue(); - LN0 = cast<LoadSDNode>(InVec.getOperand(0)); - } else if ((SVN = dyn_cast<ShuffleVectorSDNode>(InVec))) { - // (vextract (vector_shuffle (load $addr), v2, <1, u, u, u>), 1) - // => - // (load $addr+1*size) + // Select the input vector, guarding against out of range extract vector. + int Idx = (Elt > (int)NumElts) ? -1 : Shuf->getMaskElt(Elt); + VecOp = (Idx < (int)NumElts) ? VecOp.getOperand(0) : VecOp.getOperand(1); + if (VecOp.getOpcode() == ISD::BITCAST) { // Don't duplicate a load with other uses. - if (!InVec.hasOneUse()) - return SDValue(); - - // If the bit convert changed the number of elements, it is unsafe - // to examine the mask. - if (BCNumEltsChanged) + if (!VecOp.hasOneUse()) return SDValue(); - // Select the input vector, guarding against out of range extract vector. - unsigned NumElems = VT.getVectorNumElements(); - int Idx = (Elt > (int)NumElems) ? -1 : SVN->getMaskElt(Elt); - InVec = (Idx < (int)NumElems) ? InVec.getOperand(0) : InVec.getOperand(1); - - if (InVec.getOpcode() == ISD::BITCAST) { - // Don't duplicate a load with other uses. - if (!InVec.hasOneUse()) - return SDValue(); - - InVec = InVec.getOperand(0); - } - if (ISD::isNormalLoad(InVec.getNode())) { - LN0 = cast<LoadSDNode>(InVec); - Elt = (Idx < (int)NumElems) ? Idx : Idx - (int)NumElems; - EltNo = DAG.getConstant(Elt, SDLoc(EltNo), EltNo.getValueType()); - } + VecOp = VecOp.getOperand(0); } + if (ISD::isNormalLoad(VecOp.getNode())) { + LN0 = cast<LoadSDNode>(VecOp); + Elt = (Idx < (int)NumElts) ? Idx : Idx - (int)NumElts; + Index = DAG.getConstant(Elt, DL, Index.getValueType()); + } + } - // Make sure we found a non-volatile load and the extractelement is - // the only use. - if (!LN0 || !LN0->hasNUsesOfValue(1,0) || LN0->isVolatile()) - return SDValue(); - - // If Idx was -1 above, Elt is going to be -1, so just return undef. - if (Elt == -1) - return DAG.getUNDEF(LVT); + // Make sure we found a non-volatile load and the extractelement is + // the only use. + if (!LN0 || !LN0->hasNUsesOfValue(1,0) || LN0->isVolatile()) + return SDValue(); - return ReplaceExtractVectorEltOfLoadWithNarrowedLoad(N, VT, EltNo, LN0); - } + // If Idx was -1 above, Elt is going to be -1, so just return undef. + if (Elt == -1) + return DAG.getUNDEF(LVT); - return SDValue(); + return scalarizeExtractedVectorLoad(N, VecVT, Index, LN0); } // Simplify (build_vec (ext )) to (bitcast (build_vec )) @@ -15477,77 +16070,6 @@ SDValue DAGCombiner::reduceBuildVecExtToExtBuildVec(SDNode *N) { return DAG.getBitcast(VT, BV); } -SDValue DAGCombiner::reduceBuildVecConvertToConvertBuildVec(SDNode *N) { - EVT VT = N->getValueType(0); - - unsigned NumInScalars = N->getNumOperands(); - SDLoc DL(N); - - EVT SrcVT = MVT::Other; - unsigned Opcode = ISD::DELETED_NODE; - unsigned NumDefs = 0; - - for (unsigned i = 0; i != NumInScalars; ++i) { - SDValue In = N->getOperand(i); - unsigned Opc = In.getOpcode(); - - if (Opc == ISD::UNDEF) - continue; - - // If all scalar values are floats and converted from integers. - if (Opcode == ISD::DELETED_NODE && - (Opc == ISD::UINT_TO_FP || Opc == ISD::SINT_TO_FP)) { - Opcode = Opc; - } - - if (Opc != Opcode) - return SDValue(); - - EVT InVT = In.getOperand(0).getValueType(); - - // If all scalar values are typed differently, bail out. It's chosen to - // simplify BUILD_VECTOR of integer types. - if (SrcVT == MVT::Other) - SrcVT = InVT; - if (SrcVT != InVT) - return SDValue(); - NumDefs++; - } - - // If the vector has just one element defined, it's not worth to fold it into - // a vectorized one. - if (NumDefs < 2) - return SDValue(); - - assert((Opcode == ISD::UINT_TO_FP || Opcode == ISD::SINT_TO_FP) - && "Should only handle conversion from integer to float."); - assert(SrcVT != MVT::Other && "Cannot determine source type!"); - - EVT NVT = EVT::getVectorVT(*DAG.getContext(), SrcVT, NumInScalars); - - if (!TLI.isOperationLegalOrCustom(Opcode, NVT)) - return SDValue(); - - // Just because the floating-point vector type is legal does not necessarily - // mean that the corresponding integer vector type is. - if (!isTypeLegal(NVT)) - return SDValue(); - - SmallVector<SDValue, 8> Opnds; - for (unsigned i = 0; i != NumInScalars; ++i) { - SDValue In = N->getOperand(i); - - if (In.isUndef()) - Opnds.push_back(DAG.getUNDEF(SrcVT)); - else - Opnds.push_back(In.getOperand(0)); - } - SDValue BV = DAG.getBuildVector(NVT, DL, Opnds); - AddToWorklist(BV.getNode()); - - return DAG.getNode(Opcode, DL, VT, BV); -} - SDValue DAGCombiner::createBuildVecShuffle(const SDLoc &DL, SDNode *N, ArrayRef<int> VectorMask, SDValue VecIn1, SDValue VecIn2, @@ -15669,6 +16191,78 @@ SDValue DAGCombiner::createBuildVecShuffle(const SDLoc &DL, SDNode *N, return Shuffle; } +static SDValue reduceBuildVecToShuffleWithZero(SDNode *BV, SelectionDAG &DAG) { + assert(BV->getOpcode() == ISD::BUILD_VECTOR && "Expected build vector"); + + // First, determine where the build vector is not undef. + // TODO: We could extend this to handle zero elements as well as undefs. + int NumBVOps = BV->getNumOperands(); + int ZextElt = -1; + for (int i = 0; i != NumBVOps; ++i) { + SDValue Op = BV->getOperand(i); + if (Op.isUndef()) + continue; + if (ZextElt == -1) + ZextElt = i; + else + return SDValue(); + } + // Bail out if there's no non-undef element. + if (ZextElt == -1) + return SDValue(); + + // The build vector contains some number of undef elements and exactly + // one other element. That other element must be a zero-extended scalar + // extracted from a vector at a constant index to turn this into a shuffle. + // TODO: This could be enhanced to allow ANY_EXTEND as well as ZERO_EXTEND. + SDValue Zext = BV->getOperand(ZextElt); + if (Zext.getOpcode() != ISD::ZERO_EXTEND || !Zext.hasOneUse() || + Zext.getOperand(0).getOpcode() != ISD::EXTRACT_VECTOR_ELT || + !isa<ConstantSDNode>(Zext.getOperand(0).getOperand(1))) + return SDValue(); + + // The zero-extend must be a multiple of the source size. + SDValue Extract = Zext.getOperand(0); + unsigned DestSize = Zext.getValueSizeInBits(); + unsigned SrcSize = Extract.getValueSizeInBits(); + if (DestSize % SrcSize != 0) + return SDValue(); + + // Create a shuffle mask that will combine the extracted element with zeros + // and undefs. + int ZextRatio = DestSize / SrcSize; + int NumMaskElts = NumBVOps * ZextRatio; + SmallVector<int, 32> ShufMask(NumMaskElts, -1); + for (int i = 0; i != NumMaskElts; ++i) { + if (i / ZextRatio == ZextElt) { + // The low bits of the (potentially translated) extracted element map to + // the source vector. The high bits map to zero. We will use a zero vector + // as the 2nd source operand of the shuffle, so use the 1st element of + // that vector (mask value is number-of-elements) for the high bits. + if (i % ZextRatio == 0) + ShufMask[i] = Extract.getConstantOperandVal(1); + else + ShufMask[i] = NumMaskElts; + } + + // Undef elements of the build vector remain undef because we initialize + // the shuffle mask with -1. + } + + // Turn this into a shuffle with zero if that's legal. + EVT VecVT = Extract.getOperand(0).getValueType(); + if (!DAG.getTargetLoweringInfo().isShuffleMaskLegal(ShufMask, VecVT)) + return SDValue(); + + // buildvec undef, ..., (zext (extractelt V, IndexC)), undef... --> + // bitcast (shuffle V, ZeroVec, VectorMask) + SDLoc DL(BV); + SDValue ZeroVec = DAG.getConstant(0, DL, VecVT); + SDValue Shuf = DAG.getVectorShuffle(VecVT, DL, Extract.getOperand(0), ZeroVec, + ShufMask); + return DAG.getBitcast(BV->getValueType(0), Shuf); +} + // Check to see if this is a BUILD_VECTOR of a bunch of EXTRACT_VECTOR_ELT // operations. If the types of the vectors we're extracting from allow it, // turn this into a vector_shuffle node. @@ -15680,6 +16274,9 @@ SDValue DAGCombiner::reduceBuildVecToShuffle(SDNode *N) { if (!isTypeLegal(VT)) return SDValue(); + if (SDValue V = reduceBuildVecToShuffleWithZero(N, DAG)) + return V; + // May only combine to shuffle after legalize if shuffle is legal. if (LegalOperations && !TLI.isOperationLegal(ISD::VECTOR_SHUFFLE, VT)) return SDValue(); @@ -15943,7 +16540,7 @@ SDValue DAGCombiner::visitBUILD_VECTOR(SDNode *N) { // TODO: Maybe this is useful for non-splat too? if (!LegalOperations) { if (SDValue Splat = cast<BuildVectorSDNode>(N)->getSplatValue()) { - Splat = peekThroughBitcast(Splat); + Splat = peekThroughBitcasts(Splat); EVT SrcVT = Splat.getValueType(); if (SrcVT.isVector()) { unsigned NumElts = N->getNumOperands() * SrcVT.getVectorNumElements(); @@ -15994,9 +16591,6 @@ SDValue DAGCombiner::visitBUILD_VECTOR(SDNode *N) { if (SDValue V = reduceBuildVecExtToExtBuildVec(N)) return V; - if (SDValue V = reduceBuildVecConvertToConvertBuildVec(N)) - return V; - if (SDValue V = reduceBuildVecToShuffle(N)) return V; @@ -16078,8 +16672,7 @@ static SDValue combineConcatVectorOfExtracts(SDNode *N, SelectionDAG &DAG) { SmallVector<int, 8> Mask; for (SDValue Op : N->ops()) { - // Peek through any bitcast. - Op = peekThroughBitcast(Op); + Op = peekThroughBitcasts(Op); // UNDEF nodes convert to UNDEF shuffle mask values. if (Op.isUndef()) { @@ -16096,9 +16689,7 @@ static SDValue combineConcatVectorOfExtracts(SDNode *N, SelectionDAG &DAG) { // We want the EVT of the original extraction to correctly scale the // extraction index. EVT ExtVT = ExtVec.getValueType(); - - // Peek through any bitcast. - ExtVec = peekThroughBitcast(ExtVec); + ExtVec = peekThroughBitcasts(ExtVec); // UNDEF nodes convert to UNDEF shuffle mask values. if (ExtVec.isUndef()) { @@ -16162,11 +16753,19 @@ SDValue DAGCombiner::visitCONCAT_VECTORS(SDNode *N) { SDValue In = N->getOperand(0); assert(In.getValueType().isVector() && "Must concat vectors"); - // Transform: concat_vectors(scalar, undef) -> scalar_to_vector(sclr). - if (In->getOpcode() == ISD::BITCAST && - !In->getOperand(0).getValueType().isVector()) { - SDValue Scalar = In->getOperand(0); + SDValue Scalar = peekThroughOneUseBitcasts(In); + // concat_vectors(scalar_to_vector(scalar), undef) -> + // scalar_to_vector(scalar) + if (!LegalOperations && Scalar.getOpcode() == ISD::SCALAR_TO_VECTOR && + Scalar.hasOneUse()) { + EVT SVT = Scalar.getValueType().getVectorElementType(); + if (SVT == Scalar.getOperand(0).getValueType()) + Scalar = Scalar.getOperand(0); + } + + // concat_vectors(scalar, undef) -> scalar_to_vector(scalar) + if (!Scalar.getValueType().isVector()) { // If the bitcast type isn't legal, it might be a trunc of a legal type; // look through the trunc so we can still do the transform: // concat_vectors(trunc(scalar), undef) -> scalar_to_vector(scalar) @@ -16175,7 +16774,7 @@ SDValue DAGCombiner::visitCONCAT_VECTORS(SDNode *N) { TLI.isTypeLegal(Scalar->getOperand(0).getValueType())) Scalar = Scalar->getOperand(0); - EVT SclTy = Scalar->getValueType(0); + EVT SclTy = Scalar.getValueType(); if (!SclTy.isFloatingPoint() && !SclTy.isInteger()) return SDValue(); @@ -16303,60 +16902,93 @@ SDValue DAGCombiner::visitCONCAT_VECTORS(SDNode *N) { return SDValue(); } -/// If we are extracting a subvector produced by a wide binary operator with at -/// at least one operand that was the result of a vector concatenation, then try -/// to use the narrow vector operands directly to avoid the concatenation and -/// extraction. +/// If we are extracting a subvector produced by a wide binary operator try +/// to use a narrow binary operator and/or avoid concatenation and extraction. static SDValue narrowExtractedVectorBinOp(SDNode *Extract, SelectionDAG &DAG) { // TODO: Refactor with the caller (visitEXTRACT_SUBVECTOR), so we can share // some of these bailouts with other transforms. // The extract index must be a constant, so we can map it to a concat operand. - auto *ExtractIndex = dyn_cast<ConstantSDNode>(Extract->getOperand(1)); - if (!ExtractIndex) - return SDValue(); - - // Only handle the case where we are doubling and then halving. A larger ratio - // may require more than two narrow binops to replace the wide binop. - EVT VT = Extract->getValueType(0); - unsigned NumElems = VT.getVectorNumElements(); - assert((ExtractIndex->getZExtValue() % NumElems) == 0 && - "Extract index is not a multiple of the vector length."); - if (Extract->getOperand(0).getValueSizeInBits() != VT.getSizeInBits() * 2) + auto *ExtractIndexC = dyn_cast<ConstantSDNode>(Extract->getOperand(1)); + if (!ExtractIndexC) return SDValue(); // We are looking for an optionally bitcasted wide vector binary operator // feeding an extract subvector. - SDValue BinOp = peekThroughBitcast(Extract->getOperand(0)); - - // TODO: The motivating case for this transform is an x86 AVX1 target. That - // target has temptingly almost legal versions of bitwise logic ops in 256-bit - // flavors, but no other 256-bit integer support. This could be extended to - // handle any binop, but that may require fixing/adding other folds to avoid - // codegen regressions. - unsigned BOpcode = BinOp.getOpcode(); - if (BOpcode != ISD::AND && BOpcode != ISD::OR && BOpcode != ISD::XOR) + SDValue BinOp = peekThroughBitcasts(Extract->getOperand(0)); + if (!ISD::isBinaryOp(BinOp.getNode())) return SDValue(); - // The binop must be a vector type, so we can chop it in half. + // The binop must be a vector type, so we can extract some fraction of it. EVT WideBVT = BinOp.getValueType(); if (!WideBVT.isVector()) return SDValue(); + EVT VT = Extract->getValueType(0); + unsigned ExtractIndex = ExtractIndexC->getZExtValue(); + assert(ExtractIndex % VT.getVectorNumElements() == 0 && + "Extract index is not a multiple of the vector length."); + + // Bail out if this is not a proper multiple width extraction. + unsigned WideWidth = WideBVT.getSizeInBits(); + unsigned NarrowWidth = VT.getSizeInBits(); + if (WideWidth % NarrowWidth != 0) + return SDValue(); + + // Bail out if we are extracting a fraction of a single operation. This can + // occur because we potentially looked through a bitcast of the binop. + unsigned NarrowingRatio = WideWidth / NarrowWidth; + unsigned WideNumElts = WideBVT.getVectorNumElements(); + if (WideNumElts % NarrowingRatio != 0) + return SDValue(); + // Bail out if the target does not support a narrower version of the binop. EVT NarrowBVT = EVT::getVectorVT(*DAG.getContext(), WideBVT.getScalarType(), - WideBVT.getVectorNumElements() / 2); + WideNumElts / NarrowingRatio); + unsigned BOpcode = BinOp.getOpcode(); const TargetLowering &TLI = DAG.getTargetLoweringInfo(); if (!TLI.isOperationLegalOrCustomOrPromote(BOpcode, NarrowBVT)) return SDValue(); - // Peek through bitcasts of the binary operator operands if needed. - SDValue LHS = peekThroughBitcast(BinOp.getOperand(0)); - SDValue RHS = peekThroughBitcast(BinOp.getOperand(1)); + // If extraction is cheap, we don't need to look at the binop operands + // for concat ops. The narrow binop alone makes this transform profitable. + // We can't just reuse the original extract index operand because we may have + // bitcasted. + unsigned ConcatOpNum = ExtractIndex / VT.getVectorNumElements(); + unsigned ExtBOIdx = ConcatOpNum * NarrowBVT.getVectorNumElements(); + EVT ExtBOIdxVT = Extract->getOperand(1).getValueType(); + if (TLI.isExtractSubvectorCheap(NarrowBVT, WideBVT, ExtBOIdx) && + BinOp.hasOneUse() && Extract->getOperand(0)->hasOneUse()) { + // extract (binop B0, B1), N --> binop (extract B0, N), (extract B1, N) + SDLoc DL(Extract); + SDValue NewExtIndex = DAG.getConstant(ExtBOIdx, DL, ExtBOIdxVT); + SDValue X = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, NarrowBVT, + BinOp.getOperand(0), NewExtIndex); + SDValue Y = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, NarrowBVT, + BinOp.getOperand(1), NewExtIndex); + SDValue NarrowBinOp = DAG.getNode(BOpcode, DL, NarrowBVT, X, Y, + BinOp.getNode()->getFlags()); + return DAG.getBitcast(VT, NarrowBinOp); + } + + // Only handle the case where we are doubling and then halving. A larger ratio + // may require more than two narrow binops to replace the wide binop. + if (NarrowingRatio != 2) + return SDValue(); + + // TODO: The motivating case for this transform is an x86 AVX1 target. That + // target has temptingly almost legal versions of bitwise logic ops in 256-bit + // flavors, but no other 256-bit integer support. This could be extended to + // handle any binop, but that may require fixing/adding other folds to avoid + // codegen regressions. + if (BOpcode != ISD::AND && BOpcode != ISD::OR && BOpcode != ISD::XOR) + return SDValue(); // We need at least one concatenation operation of a binop operand to make // this transform worthwhile. The concat must double the input vector sizes. // TODO: Should we also handle INSERT_SUBVECTOR patterns? + SDValue LHS = peekThroughBitcasts(BinOp.getOperand(0)); + SDValue RHS = peekThroughBitcasts(BinOp.getOperand(1)); bool ConcatL = LHS.getOpcode() == ISD::CONCAT_VECTORS && LHS.getNumOperands() == 2; bool ConcatR = @@ -16365,11 +16997,7 @@ static SDValue narrowExtractedVectorBinOp(SDNode *Extract, SelectionDAG &DAG) { return SDValue(); // If one of the binop operands was not the result of a concat, we must - // extract a half-sized operand for our new narrow binop. We can't just reuse - // the original extract index operand because we may have bitcasted. - unsigned ConcatOpNum = ExtractIndex->getZExtValue() / NumElems; - unsigned ExtBOIdx = ConcatOpNum * NarrowBVT.getVectorNumElements(); - EVT ExtBOIdxVT = Extract->getOperand(1).getValueType(); + // extract a half-sized operand for our new narrow binop. SDLoc DL(Extract); // extract (binop (concat X1, X2), (concat Y1, Y2)), N --> binop XN, YN @@ -16397,17 +17025,19 @@ static SDValue narrowExtractedVectorLoad(SDNode *Extract, SelectionDAG &DAG) { if (DAG.getDataLayout().isBigEndian()) return SDValue(); - // TODO: The one-use check is overly conservative. Check the cost of the - // extract instead or remove that condition entirely. auto *Ld = dyn_cast<LoadSDNode>(Extract->getOperand(0)); auto *ExtIdx = dyn_cast<ConstantSDNode>(Extract->getOperand(1)); - if (!Ld || !Ld->hasOneUse() || Ld->getExtensionType() || Ld->isVolatile() || - !ExtIdx) + if (!Ld || Ld->getExtensionType() || Ld->isVolatile() || !ExtIdx) + return SDValue(); + + // Allow targets to opt-out. + EVT VT = Extract->getValueType(0); + const TargetLowering &TLI = DAG.getTargetLoweringInfo(); + if (!TLI.shouldReduceLoadWidth(Ld, Ld->getExtensionType(), VT)) return SDValue(); // The narrow load will be offset from the base address of the old load if // we are extracting from something besides index 0 (little-endian). - EVT VT = Extract->getValueType(0); SDLoc DL(Extract); SDValue BaseAddr = Ld->getOperand(1); unsigned Offset = ExtIdx->getZExtValue() * VT.getScalarType().getStoreSize(); @@ -16440,9 +17070,9 @@ SDValue DAGCombiner::visitEXTRACT_SUBVECTOR(SDNode* N) { // Vi if possible // Only operand 0 is checked as 'concat' assumes all inputs of the same // type. - if (V->getOpcode() == ISD::CONCAT_VECTORS && + if (V.getOpcode() == ISD::CONCAT_VECTORS && isa<ConstantSDNode>(N->getOperand(1)) && - V->getOperand(0).getValueType() == NVT) { + V.getOperand(0).getValueType() == NVT) { unsigned Idx = N->getConstantOperandVal(1); unsigned NumElems = NVT.getVectorNumElements(); assert((Idx % NumElems) == 0 && @@ -16450,13 +17080,12 @@ SDValue DAGCombiner::visitEXTRACT_SUBVECTOR(SDNode* N) { return V->getOperand(Idx / NumElems); } - // Skip bitcasting - V = peekThroughBitcast(V); + V = peekThroughBitcasts(V); // If the input is a build vector. Try to make a smaller build vector. - if (V->getOpcode() == ISD::BUILD_VECTOR) { + if (V.getOpcode() == ISD::BUILD_VECTOR) { if (auto *Idx = dyn_cast<ConstantSDNode>(N->getOperand(1))) { - EVT InVT = V->getValueType(0); + EVT InVT = V.getValueType(); unsigned ExtractSize = NVT.getSizeInBits(); unsigned EltSize = InVT.getScalarSizeInBits(); // Only do this if we won't split any elements. @@ -16489,16 +17118,16 @@ SDValue DAGCombiner::visitEXTRACT_SUBVECTOR(SDNode* N) { } } - if (V->getOpcode() == ISD::INSERT_SUBVECTOR) { + if (V.getOpcode() == ISD::INSERT_SUBVECTOR) { // Handle only simple case where vector being inserted and vector // being extracted are of same size. - EVT SmallVT = V->getOperand(1).getValueType(); + EVT SmallVT = V.getOperand(1).getValueType(); if (!NVT.bitsEq(SmallVT)) return SDValue(); // Only handle cases where both indexes are constants. - ConstantSDNode *ExtIdx = dyn_cast<ConstantSDNode>(N->getOperand(1)); - ConstantSDNode *InsIdx = dyn_cast<ConstantSDNode>(V->getOperand(2)); + auto *ExtIdx = dyn_cast<ConstantSDNode>(N->getOperand(1)); + auto *InsIdx = dyn_cast<ConstantSDNode>(V.getOperand(2)); if (InsIdx && ExtIdx) { // Combine: @@ -16508,11 +17137,11 @@ SDValue DAGCombiner::visitEXTRACT_SUBVECTOR(SDNode* N) { // otherwise => (extract_subvec V1, ExtIdx) if (InsIdx->getZExtValue() * SmallVT.getScalarSizeInBits() == ExtIdx->getZExtValue() * NVT.getScalarSizeInBits()) - return DAG.getBitcast(NVT, V->getOperand(1)); + return DAG.getBitcast(NVT, V.getOperand(1)); return DAG.getNode( ISD::EXTRACT_SUBVECTOR, SDLoc(N), NVT, - DAG.getBitcast(N->getOperand(0).getValueType(), V->getOperand(0)), - N->getOperand(1)); + DAG.getBitcast(N->getOperand(0).getValueType(), V.getOperand(0)), + N->getOperand(1)); } } @@ -16613,14 +17242,17 @@ static SDValue combineShuffleOfScalars(ShuffleVectorSDNode *SVN, SDValue N0 = SVN->getOperand(0); SDValue N1 = SVN->getOperand(1); - if (!N0->hasOneUse() || !N1->hasOneUse()) + if (!N0->hasOneUse()) return SDValue(); // If only one of N1,N2 is constant, bail out if it is not ALL_ZEROS as // discussed above. if (!N1.isUndef()) { - bool N0AnyConst = isAnyConstantBuildVector(N0.getNode()); - bool N1AnyConst = isAnyConstantBuildVector(N1.getNode()); + if (!N1->hasOneUse()) + return SDValue(); + + bool N0AnyConst = isAnyConstantBuildVector(N0); + bool N1AnyConst = isAnyConstantBuildVector(N1); if (N0AnyConst && !N1AnyConst && !ISD::isBuildVectorAllZeros(N0.getNode())) return SDValue(); if (!N0AnyConst && N1AnyConst && !ISD::isBuildVectorAllZeros(N1.getNode())) @@ -16686,8 +17318,7 @@ static SDValue combineShuffleOfScalars(ShuffleVectorSDNode *SVN, static SDValue combineShuffleToVectorExtend(ShuffleVectorSDNode *SVN, SelectionDAG &DAG, const TargetLowering &TLI, - bool LegalOperations, - bool LegalTypes) { + bool LegalOperations) { EVT VT = SVN->getValueType(0); bool IsBigEndian = DAG.getDataLayout().isBigEndian(); @@ -16723,11 +17354,14 @@ static SDValue combineShuffleToVectorExtend(ShuffleVectorSDNode *SVN, EVT OutSVT = EVT::getIntegerVT(*DAG.getContext(), EltSizeInBits * Scale); EVT OutVT = EVT::getVectorVT(*DAG.getContext(), OutSVT, NumElts / Scale); - if (!LegalTypes || TLI.isTypeLegal(OutVT)) + // Never create an illegal type. Only create unsupported operations if we + // are pre-legalization. + if (TLI.isTypeLegal(OutVT)) if (!LegalOperations || TLI.isOperationLegalOrCustom(ISD::ANY_EXTEND_VECTOR_INREG, OutVT)) return DAG.getBitcast(VT, - DAG.getAnyExtendVectorInReg(N0, SDLoc(SVN), OutVT)); + DAG.getNode(ISD::ANY_EXTEND_VECTOR_INREG, + SDLoc(SVN), OutVT, N0)); } return SDValue(); @@ -16747,7 +17381,7 @@ static SDValue combineTruncationShuffle(ShuffleVectorSDNode *SVN, if (!VT.isInteger() || IsBigEndian) return SDValue(); - SDValue N0 = peekThroughBitcast(SVN->getOperand(0)); + SDValue N0 = peekThroughBitcasts(SVN->getOperand(0)); unsigned Opcode = N0.getOpcode(); if (Opcode != ISD::ANY_EXTEND_VECTOR_INREG && @@ -17032,7 +17666,7 @@ SDValue DAGCombiner::visitVECTOR_SHUFFLE(SDNode *N) { return SDValue(N, 0); // Match shuffles that can be converted to any_vector_extend_in_reg. - if (SDValue V = combineShuffleToVectorExtend(SVN, DAG, TLI, LegalOperations, LegalTypes)) + if (SDValue V = combineShuffleToVectorExtend(SVN, DAG, TLI, LegalOperations)) return V; // Combine "truncate_vector_in_reg" style shuffles. @@ -17050,7 +17684,7 @@ SDValue DAGCombiner::visitVECTOR_SHUFFLE(SDNode *N) { // Attempt to combine a shuffle of 2 inputs of 'scalar sources' - // BUILD_VECTOR or SCALAR_TO_VECTOR into a single BUILD_VECTOR. - if (Level < AfterLegalizeVectorOps && TLI.isTypeLegal(VT)) + if (Level < AfterLegalizeDAG && TLI.isTypeLegal(VT)) if (SDValue Res = combineShuffleOfScalars(SVN, DAG, TLI)) return Res; @@ -17060,15 +17694,6 @@ SDValue DAGCombiner::visitVECTOR_SHUFFLE(SDNode *N) { if (N0.getOpcode() == ISD::BITCAST && N0.hasOneUse() && N1.isUndef() && Level < AfterLegalizeVectorOps && TLI.isTypeLegal(VT)) { - - // Peek through the bitcast only if there is one user. - SDValue BC0 = N0; - while (BC0.getOpcode() == ISD::BITCAST) { - if (!BC0.hasOneUse()) - break; - BC0 = BC0.getOperand(0); - } - auto ScaleShuffleMask = [](ArrayRef<int> Mask, int Scale) { if (Scale == 1) return SmallVector<int, 8>(Mask.begin(), Mask.end()); @@ -17079,7 +17704,8 @@ SDValue DAGCombiner::visitVECTOR_SHUFFLE(SDNode *N) { NewMask.push_back(M < 0 ? -1 : Scale * M + s); return NewMask; }; - + + SDValue BC0 = peekThroughOneUseBitcasts(N0); if (BC0.getOpcode() == ISD::VECTOR_SHUFFLE && BC0.hasOneUse()) { EVT SVT = VT.getScalarType(); EVT InnerVT = BC0->getValueType(0); @@ -17322,12 +17948,6 @@ SDValue DAGCombiner::visitINSERT_SUBVECTOR(SDNode *N) { if (N1.isUndef()) return N0; - // For nested INSERT_SUBVECTORs, attempt to combine inner node first to allow - // us to pull BITCASTs from input to output. - if (N0.hasOneUse() && N0->getOpcode() == ISD::INSERT_SUBVECTOR) - if (SDValue NN0 = visitINSERT_SUBVECTOR(N0.getNode())) - return DAG.getNode(ISD::INSERT_SUBVECTOR, SDLoc(N), VT, NN0, N1, N2); - // If this is an insert of an extracted vector into an undef vector, we can // just use the input to the extract. if (N0.isUndef() && N1.getOpcode() == ISD::EXTRACT_SUBVECTOR && @@ -17375,6 +17995,14 @@ SDValue DAGCombiner::visitINSERT_SUBVECTOR(SDNode *N) { return DAG.getNode(ISD::INSERT_SUBVECTOR, SDLoc(N), VT, N0.getOperand(0), N1, N2); + // Eliminate an intermediate insert into an undef vector: + // insert_subvector undef, (insert_subvector undef, X, 0), N2 --> + // insert_subvector undef, X, N2 + if (N0.isUndef() && N1.getOpcode() == ISD::INSERT_SUBVECTOR && + N1.getOperand(0).isUndef() && isNullConstant(N1.getOperand(2))) + return DAG.getNode(ISD::INSERT_SUBVECTOR, SDLoc(N), VT, N0, + N1.getOperand(1), N2); + if (!isa<ConstantSDNode>(N2)) return SDValue(); @@ -17410,6 +18038,10 @@ SDValue DAGCombiner::visitINSERT_SUBVECTOR(SDNode *N) { return DAG.getNode(ISD::CONCAT_VECTORS, SDLoc(N), VT, Ops); } + // Simplify source operands based on insertion. + if (SimplifyDemandedVectorElts(SDValue(N, 0))) + return SDValue(N, 0); + return SDValue(); } @@ -17447,7 +18079,7 @@ SDValue DAGCombiner::XformToShuffleWithZero(SDNode *N) { EVT VT = N->getValueType(0); SDValue LHS = N->getOperand(0); - SDValue RHS = peekThroughBitcast(N->getOperand(1)); + SDValue RHS = peekThroughBitcasts(N->getOperand(1)); SDLoc DL(N); // Make sure we're not running after operation legalization where it @@ -17677,31 +18309,64 @@ bool DAGCombiner::SimplifySelectOps(SDNode *TheSelect, SDValue LHS, LLD->getBasePtr().getValueType())) return false; + // The loads must not depend on one another. + if (LLD->isPredecessorOf(RLD) || RLD->isPredecessorOf(LLD)) + return false; + // Check that the select condition doesn't reach either load. If so, // folding this will induce a cycle into the DAG. If not, this is safe to // xform, so create a select of the addresses. + + SmallPtrSet<const SDNode *, 32> Visited; + SmallVector<const SDNode *, 16> Worklist; + + // Always fail if LLD and RLD are not independent. TheSelect is a + // predecessor to all Nodes in question so we need not search past it. + + Visited.insert(TheSelect); + Worklist.push_back(LLD); + Worklist.push_back(RLD); + + if (SDNode::hasPredecessorHelper(LLD, Visited, Worklist) || + SDNode::hasPredecessorHelper(RLD, Visited, Worklist)) + return false; + SDValue Addr; if (TheSelect->getOpcode() == ISD::SELECT) { + // We cannot do this optimization if any pair of {RLD, LLD} is a + // predecessor to {RLD, LLD, CondNode}. As we've already compared the + // Loads, we only need to check if CondNode is a successor to one of the + // loads. We can further avoid this if there's no use of their chain + // value. SDNode *CondNode = TheSelect->getOperand(0).getNode(); - if ((LLD->hasAnyUseOfValue(1) && LLD->isPredecessorOf(CondNode)) || - (RLD->hasAnyUseOfValue(1) && RLD->isPredecessorOf(CondNode))) - return false; - // The loads must not depend on one another. - if (LLD->isPredecessorOf(RLD) || - RLD->isPredecessorOf(LLD)) + Worklist.push_back(CondNode); + + if ((LLD->hasAnyUseOfValue(1) && + SDNode::hasPredecessorHelper(LLD, Visited, Worklist)) || + (RLD->hasAnyUseOfValue(1) && + SDNode::hasPredecessorHelper(RLD, Visited, Worklist))) return false; + Addr = DAG.getSelect(SDLoc(TheSelect), LLD->getBasePtr().getValueType(), TheSelect->getOperand(0), LLD->getBasePtr(), RLD->getBasePtr()); } else { // Otherwise SELECT_CC + // We cannot do this optimization if any pair of {RLD, LLD} is a + // predecessor to {RLD, LLD, CondLHS, CondRHS}. As we've already compared + // the Loads, we only need to check if CondLHS/CondRHS is a successor to + // one of the loads. We can further avoid this if there's no use of their + // chain value. + SDNode *CondLHS = TheSelect->getOperand(0).getNode(); SDNode *CondRHS = TheSelect->getOperand(1).getNode(); + Worklist.push_back(CondLHS); + Worklist.push_back(CondRHS); if ((LLD->hasAnyUseOfValue(1) && - (LLD->isPredecessorOf(CondLHS) || LLD->isPredecessorOf(CondRHS))) || + SDNode::hasPredecessorHelper(LLD, Visited, Worklist)) || (RLD->hasAnyUseOfValue(1) && - (RLD->isPredecessorOf(CondLHS) || RLD->isPredecessorOf(CondRHS)))) + SDNode::hasPredecessorHelper(RLD, Visited, Worklist))) return false; Addr = DAG.getNode(ISD::SELECT_CC, SDLoc(TheSelect), @@ -17816,6 +18481,63 @@ SDValue DAGCombiner::foldSelectCCToShiftAnd(const SDLoc &DL, SDValue N0, return DAG.getNode(ISD::AND, DL, AType, Shift, N2); } +/// Turn "(a cond b) ? 1.0f : 2.0f" into "load (tmp + ((a cond b) ? 0 : 4)" +/// where "tmp" is a constant pool entry containing an array with 1.0 and 2.0 +/// in it. This may be a win when the constant is not otherwise available +/// because it replaces two constant pool loads with one. +SDValue DAGCombiner::convertSelectOfFPConstantsToLoadOffset( + const SDLoc &DL, SDValue N0, SDValue N1, SDValue N2, SDValue N3, + ISD::CondCode CC) { + if (!TLI.reduceSelectOfFPConstantLoads(N0.getValueType().isFloatingPoint())) + return SDValue(); + + // If we are before legalize types, we want the other legalization to happen + // first (for example, to avoid messing with soft float). + auto *TV = dyn_cast<ConstantFPSDNode>(N2); + auto *FV = dyn_cast<ConstantFPSDNode>(N3); + EVT VT = N2.getValueType(); + if (!TV || !FV || !TLI.isTypeLegal(VT)) + return SDValue(); + + // If a constant can be materialized without loads, this does not make sense. + if (TLI.getOperationAction(ISD::ConstantFP, VT) == TargetLowering::Legal || + TLI.isFPImmLegal(TV->getValueAPF(), TV->getValueType(0)) || + TLI.isFPImmLegal(FV->getValueAPF(), FV->getValueType(0))) + return SDValue(); + + // If both constants have multiple uses, then we won't need to do an extra + // load. The values are likely around in registers for other users. + if (!TV->hasOneUse() && !FV->hasOneUse()) + return SDValue(); + + Constant *Elts[] = { const_cast<ConstantFP*>(FV->getConstantFPValue()), + const_cast<ConstantFP*>(TV->getConstantFPValue()) }; + Type *FPTy = Elts[0]->getType(); + const DataLayout &TD = DAG.getDataLayout(); + + // Create a ConstantArray of the two constants. + Constant *CA = ConstantArray::get(ArrayType::get(FPTy, 2), Elts); + SDValue CPIdx = DAG.getConstantPool(CA, TLI.getPointerTy(DAG.getDataLayout()), + TD.getPrefTypeAlignment(FPTy)); + unsigned Alignment = cast<ConstantPoolSDNode>(CPIdx)->getAlignment(); + + // Get offsets to the 0 and 1 elements of the array, so we can select between + // them. + SDValue Zero = DAG.getIntPtrConstant(0, DL); + unsigned EltSize = (unsigned)TD.getTypeAllocSize(Elts[0]->getType()); + SDValue One = DAG.getIntPtrConstant(EltSize, SDLoc(FV)); + SDValue Cond = + DAG.getSetCC(DL, getSetCCResultType(N0.getValueType()), N0, N1, CC); + AddToWorklist(Cond.getNode()); + SDValue CstOffset = DAG.getSelect(DL, Zero.getValueType(), Cond, One, Zero); + AddToWorklist(CstOffset.getNode()); + CPIdx = DAG.getNode(ISD::ADD, DL, CPIdx.getValueType(), CPIdx, CstOffset); + AddToWorklist(CPIdx.getNode()); + return DAG.getLoad(TV->getValueType(0), DL, DAG.getEntryNode(), CPIdx, + MachinePointerInfo::getConstantPool( + DAG.getMachineFunction()), Alignment); +} + /// Simplify an expression of the form (N0 cond N1) ? N2 : N3 /// where 'cond' is the comparison specified by CC. SDValue DAGCombiner::SimplifySelectCC(const SDLoc &DL, SDValue N0, SDValue N1, @@ -17824,75 +18546,26 @@ SDValue DAGCombiner::SimplifySelectCC(const SDLoc &DL, SDValue N0, SDValue N1, // (x ? y : y) -> y. if (N2 == N3) return N2; + EVT CmpOpVT = N0.getValueType(); EVT VT = N2.getValueType(); - ConstantSDNode *N1C = dyn_cast<ConstantSDNode>(N1.getNode()); - ConstantSDNode *N2C = dyn_cast<ConstantSDNode>(N2.getNode()); + auto *N1C = dyn_cast<ConstantSDNode>(N1.getNode()); + auto *N2C = dyn_cast<ConstantSDNode>(N2.getNode()); + auto *N3C = dyn_cast<ConstantSDNode>(N3.getNode()); - // Determine if the condition we're dealing with is constant - SDValue SCC = SimplifySetCC(getSetCCResultType(N0.getValueType()), - N0, N1, CC, DL, false); + // Determine if the condition we're dealing with is constant. + SDValue SCC = SimplifySetCC(getSetCCResultType(CmpOpVT), N0, N1, CC, DL, + false); if (SCC.getNode()) AddToWorklist(SCC.getNode()); - if (ConstantSDNode *SCCC = dyn_cast_or_null<ConstantSDNode>(SCC.getNode())) { + if (auto *SCCC = dyn_cast_or_null<ConstantSDNode>(SCC.getNode())) { // fold select_cc true, x, y -> x // fold select_cc false, x, y -> y return !SCCC->isNullValue() ? N2 : N3; } - // Turn "(a cond b) ? 1.0f : 2.0f" into "load (tmp + ((a cond b) ? 0 : 4)" - // where "tmp" is a constant pool entry containing an array with 1.0 and 2.0 - // in it. This is a win when the constant is not otherwise available because - // it replaces two constant pool loads with one. We only do this if the FP - // type is known to be legal, because if it isn't, then we are before legalize - // types an we want the other legalization to happen first (e.g. to avoid - // messing with soft float) and if the ConstantFP is not legal, because if - // it is legal, we may not need to store the FP constant in a constant pool. - if (ConstantFPSDNode *TV = dyn_cast<ConstantFPSDNode>(N2)) - if (ConstantFPSDNode *FV = dyn_cast<ConstantFPSDNode>(N3)) { - if (TLI.isTypeLegal(N2.getValueType()) && - (TLI.getOperationAction(ISD::ConstantFP, N2.getValueType()) != - TargetLowering::Legal && - !TLI.isFPImmLegal(TV->getValueAPF(), TV->getValueType(0)) && - !TLI.isFPImmLegal(FV->getValueAPF(), FV->getValueType(0))) && - // If both constants have multiple uses, then we won't need to do an - // extra load, they are likely around in registers for other users. - (TV->hasOneUse() || FV->hasOneUse())) { - Constant *Elts[] = { - const_cast<ConstantFP*>(FV->getConstantFPValue()), - const_cast<ConstantFP*>(TV->getConstantFPValue()) - }; - Type *FPTy = Elts[0]->getType(); - const DataLayout &TD = DAG.getDataLayout(); - - // Create a ConstantArray of the two constants. - Constant *CA = ConstantArray::get(ArrayType::get(FPTy, 2), Elts); - SDValue CPIdx = - DAG.getConstantPool(CA, TLI.getPointerTy(DAG.getDataLayout()), - TD.getPrefTypeAlignment(FPTy)); - unsigned Alignment = cast<ConstantPoolSDNode>(CPIdx)->getAlignment(); - - // Get the offsets to the 0 and 1 element of the array so that we can - // select between them. - SDValue Zero = DAG.getIntPtrConstant(0, DL); - unsigned EltSize = (unsigned)TD.getTypeAllocSize(Elts[0]->getType()); - SDValue One = DAG.getIntPtrConstant(EltSize, SDLoc(FV)); - - SDValue Cond = DAG.getSetCC(DL, - getSetCCResultType(N0.getValueType()), - N0, N1, CC); - AddToWorklist(Cond.getNode()); - SDValue CstOffset = DAG.getSelect(DL, Zero.getValueType(), - Cond, One, Zero); - AddToWorklist(CstOffset.getNode()); - CPIdx = DAG.getNode(ISD::ADD, DL, CPIdx.getValueType(), CPIdx, - CstOffset); - AddToWorklist(CPIdx.getNode()); - return DAG.getLoad( - TV->getValueType(0), DL, DAG.getEntryNode(), CPIdx, - MachinePointerInfo::getConstantPool(DAG.getMachineFunction()), - Alignment); - } - } + if (SDValue V = + convertSelectOfFPConstantsToLoadOffset(DL, N0, N1, N2, N3, CC)) + return V; if (SDValue V = foldSelectCCToShiftAnd(DL, N0, N1, N2, N3, CC)) return V; @@ -17906,7 +18579,7 @@ SDValue DAGCombiner::SimplifySelectCC(const SDLoc &DL, SDValue N0, SDValue N1, if (CC == ISD::SETEQ && N0->getOpcode() == ISD::AND && N0->getValueType(0) == VT && isNullConstant(N1) && isNullConstant(N2)) { SDValue AndLHS = N0->getOperand(0); - ConstantSDNode *ConstAndRHS = dyn_cast<ConstantSDNode>(N0->getOperand(1)); + auto *ConstAndRHS = dyn_cast<ConstantSDNode>(N0->getOperand(1)); if (ConstAndRHS && ConstAndRHS->getAPIntValue().countPopulation() == 1) { // Shift the tested bit over the sign bit. const APInt &AndMask = ConstAndRHS->getAPIntValue(); @@ -17927,48 +18600,48 @@ SDValue DAGCombiner::SimplifySelectCC(const SDLoc &DL, SDValue N0, SDValue N1, } // fold select C, 16, 0 -> shl C, 4 - if (N2C && isNullConstant(N3) && N2C->getAPIntValue().isPowerOf2() && - TLI.getBooleanContents(N0.getValueType()) == - TargetLowering::ZeroOrOneBooleanContent) { + bool Fold = N2C && isNullConstant(N3) && N2C->getAPIntValue().isPowerOf2(); + bool Swap = N3C && isNullConstant(N2) && N3C->getAPIntValue().isPowerOf2(); + + if ((Fold || Swap) && + TLI.getBooleanContents(CmpOpVT) == + TargetLowering::ZeroOrOneBooleanContent && + (!LegalOperations || TLI.isOperationLegal(ISD::SETCC, CmpOpVT))) { + + if (Swap) { + CC = ISD::getSetCCInverse(CC, CmpOpVT.isInteger()); + std::swap(N2C, N3C); + } // If the caller doesn't want us to simplify this into a zext of a compare, // don't do it. if (NotExtCompare && N2C->isOne()) return SDValue(); - // Get a SetCC of the condition - // NOTE: Don't create a SETCC if it's not legal on this target. - if (!LegalOperations || - TLI.isOperationLegal(ISD::SETCC, N0.getValueType())) { - SDValue Temp, SCC; - // cast from setcc result type to select result type - if (LegalTypes) { - SCC = DAG.getSetCC(DL, getSetCCResultType(N0.getValueType()), - N0, N1, CC); - if (N2.getValueType().bitsLT(SCC.getValueType())) - Temp = DAG.getZeroExtendInReg(SCC, SDLoc(N2), - N2.getValueType()); - else - Temp = DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N2), - N2.getValueType(), SCC); - } else { - SCC = DAG.getSetCC(SDLoc(N0), MVT::i1, N0, N1, CC); - Temp = DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N2), - N2.getValueType(), SCC); - } + SDValue Temp, SCC; + // zext (setcc n0, n1) + if (LegalTypes) { + SCC = DAG.getSetCC(DL, getSetCCResultType(CmpOpVT), N0, N1, CC); + if (VT.bitsLT(SCC.getValueType())) + Temp = DAG.getZeroExtendInReg(SCC, SDLoc(N2), VT); + else + Temp = DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N2), VT, SCC); + } else { + SCC = DAG.getSetCC(SDLoc(N0), MVT::i1, N0, N1, CC); + Temp = DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N2), VT, SCC); + } - AddToWorklist(SCC.getNode()); - AddToWorklist(Temp.getNode()); + AddToWorklist(SCC.getNode()); + AddToWorklist(Temp.getNode()); - if (N2C->isOne()) - return Temp; + if (N2C->isOne()) + return Temp; - // shl setcc result by log2 n2c - return DAG.getNode( - ISD::SHL, DL, N2.getValueType(), Temp, - DAG.getConstant(N2C->getAPIntValue().logBase2(), SDLoc(Temp), - getShiftAmountTy(Temp.getValueType()))); - } + // shl setcc result by log2 n2c + return DAG.getNode(ISD::SHL, DL, N2.getValueType(), Temp, + DAG.getConstant(N2C->getAPIntValue().logBase2(), + SDLoc(Temp), + getShiftAmountTy(Temp.getValueType()))); } // Check to see if this is an integer abs. @@ -17988,18 +18661,16 @@ SDValue DAGCombiner::SimplifySelectCC(const SDLoc &DL, SDValue N0, SDValue N1, N0 == N3 && N2.getOpcode() == ISD::SUB && N0 == N2.getOperand(1)) SubC = dyn_cast<ConstantSDNode>(N2.getOperand(0)); - EVT XType = N0.getValueType(); - if (SubC && SubC->isNullValue() && XType.isInteger()) { + if (SubC && SubC->isNullValue() && CmpOpVT.isInteger()) { SDLoc DL(N0); - SDValue Shift = DAG.getNode(ISD::SRA, DL, XType, - N0, - DAG.getConstant(XType.getSizeInBits() - 1, DL, - getShiftAmountTy(N0.getValueType()))); - SDValue Add = DAG.getNode(ISD::ADD, DL, - XType, N0, Shift); + SDValue Shift = DAG.getNode(ISD::SRA, DL, CmpOpVT, N0, + DAG.getConstant(CmpOpVT.getSizeInBits() - 1, + DL, + getShiftAmountTy(CmpOpVT))); + SDValue Add = DAG.getNode(ISD::ADD, DL, CmpOpVT, N0, Shift); AddToWorklist(Shift.getNode()); AddToWorklist(Add.getNode()); - return DAG.getNode(ISD::XOR, DL, XType, Add, Shift); + return DAG.getNode(ISD::XOR, DL, CmpOpVT, Add, Shift); } } @@ -18060,21 +18731,14 @@ SDValue DAGCombiner::BuildSDIV(SDNode *N) { if (DAG.getMachineFunction().getFunction().optForMinSize()) return SDValue(); - ConstantSDNode *C = isConstOrConstSplat(N->getOperand(1)); - if (!C) - return SDValue(); - - // Avoid division by zero. - if (C->isNullValue()) - return SDValue(); - SmallVector<SDNode *, 8> Built; - SDValue S = - TLI.BuildSDIV(N, C->getAPIntValue(), DAG, LegalOperations, Built); + if (SDValue S = TLI.BuildSDIV(N, DAG, LegalOperations, Built)) { + for (SDNode *N : Built) + AddToWorklist(N); + return S; + } - for (SDNode *N : Built) - AddToWorklist(N); - return S; + return SDValue(); } /// Given an ISD::SDIV node expressing a divide by constant power of 2, return a @@ -18089,11 +18753,13 @@ SDValue DAGCombiner::BuildSDIVPow2(SDNode *N) { return SDValue(); SmallVector<SDNode *, 8> Built; - SDValue S = TLI.BuildSDIVPow2(N, C->getAPIntValue(), DAG, Built); + if (SDValue S = TLI.BuildSDIVPow2(N, C->getAPIntValue(), DAG, Built)) { + for (SDNode *N : Built) + AddToWorklist(N); + return S; + } - for (SDNode *N : Built) - AddToWorklist(N); - return S; + return SDValue(); } /// Given an ISD::UDIV node expressing a divide by constant, return a DAG @@ -18106,21 +18772,14 @@ SDValue DAGCombiner::BuildUDIV(SDNode *N) { if (DAG.getMachineFunction().getFunction().optForMinSize()) return SDValue(); - ConstantSDNode *C = isConstOrConstSplat(N->getOperand(1)); - if (!C) - return SDValue(); - - // Avoid division by zero. - if (C->isNullValue()) - return SDValue(); - SmallVector<SDNode *, 8> Built; - SDValue S = - TLI.BuildUDIV(N, C->getAPIntValue(), DAG, LegalOperations, Built); + if (SDValue S = TLI.BuildUDIV(N, DAG, LegalOperations, Built)) { + for (SDNode *N : Built) + AddToWorklist(N); + return S; + } - for (SDNode *N : Built) - AddToWorklist(N); - return S; + return SDValue(); } /// Determines the LogBase2 value for a non-null input value using the @@ -18576,6 +19235,11 @@ SDValue DAGCombiner::FindBetterChain(SDNode *N, SDValue OldChain) { return DAG.getNode(ISD::TokenFactor, SDLoc(N), MVT::Other, Aliases); } +// TODO: Replace with with std::monostate when we move to C++17. +struct UnitT { } Unit; +bool operator==(const UnitT &, const UnitT &) { return true; } +bool operator!=(const UnitT &, const UnitT &) { return false; } + // This function tries to collect a bunch of potentially interesting // nodes to improve the chains of, all at once. This might seem // redundant, as this function gets called when visiting every store @@ -18588,13 +19252,22 @@ SDValue DAGCombiner::FindBetterChain(SDNode *N, SDValue OldChain) { // the nodes that will eventually be candidates, and then not be able // to go from a partially-merged state to the desired final // fully-merged state. -bool DAGCombiner::findBetterNeighborChains(StoreSDNode *St) { - if (OptLevel == CodeGenOpt::None) - return false; + +bool DAGCombiner::parallelizeChainedStores(StoreSDNode *St) { + SmallVector<StoreSDNode *, 8> ChainedStores; + StoreSDNode *STChain = St; + // Intervals records which offsets from BaseIndex have been covered. In + // the common case, every store writes to the immediately previous address + // space and thus merged with the previous interval at insertion time. + + using IMap = + llvm::IntervalMap<int64_t, UnitT, 8, IntervalMapHalfOpenInfo<int64_t>>; + IMap::Allocator A; + IMap Intervals(A); // This holds the base pointer, index, and the offset in bytes from the base // pointer. - BaseIndexOffset BasePtr = BaseIndexOffset::match(St, DAG); + const BaseIndexOffset BasePtr = BaseIndexOffset::match(St, DAG); // We must have a base and an offset. if (!BasePtr.getBase().getNode()) @@ -18604,76 +19277,114 @@ bool DAGCombiner::findBetterNeighborChains(StoreSDNode *St) { if (BasePtr.getBase().isUndef()) return false; - SmallVector<StoreSDNode *, 8> ChainedStores; - ChainedStores.push_back(St); + // Add ST's interval. + Intervals.insert(0, (St->getMemoryVT().getSizeInBits() + 7) / 8, Unit); - // Walk up the chain and look for nodes with offsets from the same - // base pointer. Stop when reaching an instruction with a different kind - // or instruction which has a different base pointer. - StoreSDNode *Index = St; - while (Index) { + while (StoreSDNode *Chain = dyn_cast<StoreSDNode>(STChain->getChain())) { // If the chain has more than one use, then we can't reorder the mem ops. - if (Index != St && !SDValue(Index, 0)->hasOneUse()) + if (!SDValue(Chain, 0)->hasOneUse()) break; - - if (Index->isVolatile() || Index->isIndexed()) + if (Chain->isVolatile() || Chain->isIndexed()) break; // Find the base pointer and offset for this memory node. - BaseIndexOffset Ptr = BaseIndexOffset::match(Index, DAG); - + const BaseIndexOffset Ptr = BaseIndexOffset::match(Chain, DAG); // Check that the base pointer is the same as the original one. - if (!BasePtr.equalBaseIndex(Ptr, DAG)) + int64_t Offset; + if (!BasePtr.equalBaseIndex(Ptr, DAG, Offset)) + break; + int64_t Length = (Chain->getMemoryVT().getSizeInBits() + 7) / 8; + // Make sure we don't overlap with other intervals by checking the ones to + // the left or right before inserting. + auto I = Intervals.find(Offset); + // If there's a next interval, we should end before it. + if (I != Intervals.end() && I.start() < (Offset + Length)) + break; + // If there's a previous interval, we should start after it. + if (I != Intervals.begin() && (--I).stop() <= Offset) break; + Intervals.insert(Offset, Offset + Length, Unit); - // Walk up the chain to find the next store node, ignoring any - // intermediate loads. Any other kind of node will halt the loop. - SDNode *NextInChain = Index->getChain().getNode(); - while (true) { - if (StoreSDNode *STn = dyn_cast<StoreSDNode>(NextInChain)) { - // We found a store node. Use it for the next iteration. - if (STn->isVolatile() || STn->isIndexed()) { - Index = nullptr; - break; - } - ChainedStores.push_back(STn); - Index = STn; - break; - } else if (LoadSDNode *Ldn = dyn_cast<LoadSDNode>(NextInChain)) { - NextInChain = Ldn->getChain().getNode(); - continue; - } else { - Index = nullptr; - break; - } - }// end while + ChainedStores.push_back(Chain); + STChain = Chain; } - // At this point, ChainedStores lists all of the Store nodes - // reachable by iterating up through chain nodes matching the above - // conditions. For each such store identified, try to find an - // earlier chain to attach the store to which won't violate the - // required ordering. - bool MadeChangeToSt = false; - SmallVector<std::pair<StoreSDNode *, SDValue>, 8> BetterChains; + // If we didn't find a chained store, exit. + if (ChainedStores.size() == 0) + return false; + + // Improve all chained stores (St and ChainedStores members) starting from + // where the store chain ended and return single TokenFactor. + SDValue NewChain = STChain->getChain(); + SmallVector<SDValue, 8> TFOps; + for (unsigned I = ChainedStores.size(); I;) { + StoreSDNode *S = ChainedStores[--I]; + SDValue BetterChain = FindBetterChain(S, NewChain); + S = cast<StoreSDNode>(DAG.UpdateNodeOperands( + S, BetterChain, S->getOperand(1), S->getOperand(2), S->getOperand(3))); + TFOps.push_back(SDValue(S, 0)); + ChainedStores[I] = S; + } + + // Improve St's chain. Use a new node to avoid creating a loop from CombineTo. + SDValue BetterChain = FindBetterChain(St, NewChain); + SDValue NewST; + if (St->isTruncatingStore()) + NewST = DAG.getTruncStore(BetterChain, SDLoc(St), St->getValue(), + St->getBasePtr(), St->getMemoryVT(), + St->getMemOperand()); + else + NewST = DAG.getStore(BetterChain, SDLoc(St), St->getValue(), + St->getBasePtr(), St->getMemOperand()); - for (StoreSDNode *ChainedStore : ChainedStores) { - SDValue Chain = ChainedStore->getChain(); - SDValue BetterChain = FindBetterChain(ChainedStore, Chain); + TFOps.push_back(NewST); - if (Chain != BetterChain) { - if (ChainedStore == St) - MadeChangeToSt = true; - BetterChains.push_back(std::make_pair(ChainedStore, BetterChain)); - } - } + // If we improved every element of TFOps, then we've lost the dependence on + // NewChain to successors of St and we need to add it back to TFOps. Do so at + // the beginning to keep relative order consistent with FindBetterChains. + auto hasImprovedChain = [&](SDValue ST) -> bool { + return ST->getOperand(0) != NewChain; + }; + bool AddNewChain = llvm::all_of(TFOps, hasImprovedChain); + if (AddNewChain) + TFOps.insert(TFOps.begin(), NewChain); + + SDValue TF = DAG.getNode(ISD::TokenFactor, SDLoc(STChain), MVT::Other, TFOps); + CombineTo(St, TF); + + AddToWorklist(STChain); + // Add TF operands worklist in reverse order. + for (auto I = TF->getNumOperands(); I;) + AddToWorklist(TF->getOperand(--I).getNode()); + AddToWorklist(TF.getNode()); + return true; +} + +bool DAGCombiner::findBetterNeighborChains(StoreSDNode *St) { + if (OptLevel == CodeGenOpt::None) + return false; - // Do all replacements after finding the replacements to make to avoid making - // the chains more complicated by introducing new TokenFactors. - for (auto Replacement : BetterChains) - replaceStoreChain(Replacement.first, Replacement.second); + const BaseIndexOffset BasePtr = BaseIndexOffset::match(St, DAG); - return MadeChangeToSt; + // We must have a base and an offset. + if (!BasePtr.getBase().getNode()) + return false; + + // Do not handle stores to undef base pointers. + if (BasePtr.getBase().isUndef()) + return false; + + // Directly improve a chain of disjoint stores starting at St. + if (parallelizeChainedStores(St)) + return true; + + // Improve St's Chain.. + SDValue BetterChain = FindBetterChain(St, St->getChain()); + if (St->getChain() != BetterChain) { + replaceStoreChain(St, BetterChain); + return true; + } + return false; } /// This is the entry point for the file. |
