Diffstat (limited to 'lib/CodeGen/SelectionDAG/DAGCombiner.cpp')
-rw-r--r-- | lib/CodeGen/SelectionDAG/DAGCombiner.cpp | 3533
1 file changed, 2164 insertions(+), 1369 deletions(-)
diff --git a/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/lib/CodeGen/SelectionDAG/DAGCombiner.cpp index 3218dce8f575..7a99687757f8 100644 --- a/lib/CodeGen/SelectionDAG/DAGCombiner.cpp +++ b/lib/CodeGen/SelectionDAG/DAGCombiner.cpp @@ -36,7 +36,6 @@ #include "llvm/CodeGen/MachineFrameInfo.h" #include "llvm/CodeGen/MachineFunction.h" #include "llvm/CodeGen/MachineMemOperand.h" -#include "llvm/CodeGen/MachineValueType.h" #include "llvm/CodeGen/RuntimeLibcalls.h" #include "llvm/CodeGen/SelectionDAG.h" #include "llvm/CodeGen/SelectionDAGAddressAnalysis.h" @@ -60,6 +59,7 @@ #include "llvm/Support/Debug.h" #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/KnownBits.h" +#include "llvm/Support/MachineValueType.h" #include "llvm/Support/MathExtras.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Target/TargetMachine.h" @@ -122,7 +122,7 @@ namespace { bool LegalTypes = false; bool ForCodeSize; - /// \brief Worklist of all of the nodes that need to be simplified. + /// Worklist of all of the nodes that need to be simplified. /// /// This must behave as a stack -- new nodes to process are pushed onto the /// back and when processing we pop off of the back. @@ -131,14 +131,14 @@ namespace { /// due to nodes being deleted from the underlying DAG. SmallVector<SDNode *, 64> Worklist; - /// \brief Mapping from an SDNode to its position on the worklist. + /// Mapping from an SDNode to its position on the worklist. /// /// This is used to find and remove nodes from the worklist (by nulling /// them) when they are deleted from the underlying DAG. It relies on /// stable indices of nodes within the worklist. DenseMap<SDNode *, unsigned> WorklistMap; - /// \brief Set of nodes which have been combined (at least once). + /// Set of nodes which have been combined (at least once). /// /// This is used to allow us to reliably add any operands of a DAG node /// which have not yet been combined to the worklist. @@ -232,14 +232,25 @@ namespace { return SimplifyDemandedBits(Op, Demanded); } + /// Check the specified vector node value to see if it can be simplified or + /// if things it uses can be simplified as it only uses some of the + /// elements. If so, return true. + bool SimplifyDemandedVectorElts(SDValue Op) { + unsigned NumElts = Op.getValueType().getVectorNumElements(); + APInt Demanded = APInt::getAllOnesValue(NumElts); + return SimplifyDemandedVectorElts(Op, Demanded); + } + bool SimplifyDemandedBits(SDValue Op, const APInt &Demanded); + bool SimplifyDemandedVectorElts(SDValue Op, const APInt &Demanded, + bool AssumeSingleUse = false); bool CombineToPreIndexedLoadStore(SDNode *N); bool CombineToPostIndexedLoadStore(SDNode *N); SDValue SplitIndexingFromLoad(LoadSDNode *LD); bool SliceUpLoad(SDNode *N); - /// \brief Replace an ISD::EXTRACT_VECTOR_ELT of a load with a narrowed + /// Replace an ISD::EXTRACT_VECTOR_ELT of a load with a narrowed /// load. /// /// \param EVE ISD::EXTRACT_VECTOR_ELT to be replaced. @@ -258,10 +269,6 @@ namespace { SDValue PromoteExtend(SDValue Op); bool PromoteLoad(SDValue Op); - void ExtendSetCCUses(const SmallVectorImpl<SDNode *> &SetCCs, SDValue Trunc, - SDValue ExtLoad, const SDLoc &DL, - ISD::NodeType ExtType); - /// Call the node-specific routine that knows how to fold each /// particular type of node. If that doesn't do anything, try the /// target-specific DAG combines. 
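The SimplifyDemandedVectorElts wrapper added above demands every lane of the vector before delegating to the APInt overload. A minimal standalone sketch of the mask it builds (plain unsigned standing in for APInt, an illustration only, not LLVM's API):

  #include <cassert>

  // All-ones lane mask for NumElts lanes, mirroring
  // APInt::getAllOnesValue(NumElts) in the wrapper above.
  unsigned allLanesDemanded(unsigned NumElts) {
    assert(NumElts > 0 && NumElts < 32 && "toy width only");
    return (1u << NumElts) - 1; // e.g. 4 lanes -> 0b1111
  }

  int main() {
    assert(allLanesDemanded(4) == 0xF); // every element of a <4 x T> vector
    return 0;
  }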
@@ -292,7 +299,9 @@ namespace { SDValue visitMUL(SDNode *N); SDValue useDivRem(SDNode *N); SDValue visitSDIV(SDNode *N); + SDValue visitSDIVLike(SDValue N0, SDValue N1, SDNode *N); SDValue visitUDIV(SDNode *N); + SDValue visitUDIVLike(SDValue N0, SDValue N1, SDNode *N); SDValue visitREM(SDNode *N); SDValue visitMULHU(SDNode *N); SDValue visitMULHS(SDNode *N); @@ -302,9 +311,9 @@ namespace { SDValue visitUMULO(SDNode *N); SDValue visitIMINMAX(SDNode *N); SDValue visitAND(SDNode *N); - SDValue visitANDLike(SDValue N0, SDValue N1, SDNode *LocReference); + SDValue visitANDLike(SDValue N0, SDValue N1, SDNode *N); SDValue visitOR(SDNode *N); - SDValue visitORLike(SDValue N0, SDValue N1, SDNode *LocReference); + SDValue visitORLike(SDValue N0, SDValue N1, SDNode *N); SDValue visitXOR(SDNode *N); SDValue SimplifyVBinOp(SDNode *N); SDValue visitSHL(SDNode *N); @@ -323,7 +332,6 @@ namespace { SDValue visitVSELECT(SDNode *N); SDValue visitSELECT_CC(SDNode *N); SDValue visitSETCC(SDNode *N); - SDValue visitSETCCE(SDNode *N); SDValue visitSETCCCARRY(SDNode *N); SDValue visitSIGN_EXTEND(SDNode *N); SDValue visitZERO_EXTEND(SDNode *N); @@ -385,8 +393,8 @@ namespace { SDValue visitFMULForFMADistributiveCombine(SDNode *N); SDValue XformToShuffleWithZero(SDNode *N); - SDValue ReassociateOps(unsigned Opc, const SDLoc &DL, SDValue LHS, - SDValue RHS); + SDValue ReassociateOps(unsigned Opc, const SDLoc &DL, SDValue N0, + SDValue N1); SDValue visitShiftByConstant(SDNode *N, ConstantSDNode *Amt); @@ -403,8 +411,11 @@ namespace { SDValue N2, SDValue N3, ISD::CondCode CC); SDValue foldLogicOfSetCCs(bool IsAnd, SDValue N0, SDValue N1, const SDLoc &DL); + SDValue unfoldMaskedMerge(SDNode *N); + SDValue unfoldExtremeBitClearingToShifts(SDNode *N); SDValue SimplifySetCC(EVT VT, SDValue N0, SDValue N1, ISD::CondCode Cond, - const SDLoc &DL, bool foldBooleans = true); + const SDLoc &DL, bool foldBooleans); + SDValue rebuildSetCC(SDValue N); bool isSetCCEquivalent(SDValue N, SDValue &LHS, SDValue &RHS, SDValue &CC) const; @@ -414,20 +425,21 @@ namespace { unsigned HiOp); SDValue CombineConsecutiveLoads(SDNode *N, EVT VT); SDValue CombineExtLoad(SDNode *N); + SDValue CombineZExtLogicopShiftLoad(SDNode *N); SDValue combineRepeatedFPDivisors(SDNode *N); SDValue combineInsertEltToShuffle(SDNode *N, unsigned InsIndex); SDValue ConstantFoldBITCASTofBUILD_VECTOR(SDNode *, EVT); SDValue BuildSDIV(SDNode *N); SDValue BuildSDIVPow2(SDNode *N); SDValue BuildUDIV(SDNode *N); - SDValue BuildLogBase2(SDValue Op, const SDLoc &DL); + SDValue BuildLogBase2(SDValue V, const SDLoc &DL); SDValue BuildReciprocalEstimate(SDValue Op, SDNodeFlags Flags); SDValue buildRsqrtEstimate(SDValue Op, SDNodeFlags Flags); SDValue buildSqrtEstimate(SDValue Op, SDNodeFlags Flags); SDValue buildSqrtEstimateImpl(SDValue Op, SDNodeFlags Flags, bool Recip); - SDValue buildSqrtNROneConst(SDValue Op, SDValue Est, unsigned Iterations, + SDValue buildSqrtNROneConst(SDValue Arg, SDValue Est, unsigned Iterations, SDNodeFlags Flags, bool Reciprocal); - SDValue buildSqrtNRTwoConst(SDValue Op, SDValue Est, unsigned Iterations, + SDValue buildSqrtNRTwoConst(SDValue Arg, SDValue Est, unsigned Iterations, SDNodeFlags Flags, bool Reciprocal); SDValue MatchBSwapHWordLow(SDNode *N, SDValue N0, SDValue N1, bool DemandHighBits = true); @@ -442,13 +454,14 @@ namespace { SDValue ReduceLoadOpStoreWidth(SDNode *N); SDValue splitMergedValStore(StoreSDNode *ST); SDValue TransformFPLoadStorePair(SDNode *N); + SDValue convertBuildVecZextToZext(SDNode *N); SDValue 
reduceBuildVecExtToExtBuildVec(SDNode *N); SDValue reduceBuildVecConvertToConvertBuildVec(SDNode *N); SDValue reduceBuildVecToShuffle(SDNode *N); SDValue createBuildVecShuffle(const SDLoc &DL, SDNode *N, ArrayRef<int> VectorMask, SDValue VecIn1, SDValue VecIn2, unsigned LeftIdx); - SDValue matchVSelectOpSizesWithSetCC(SDNode *N); + SDValue matchVSelectOpSizesWithSetCC(SDNode *Cast); /// Walk up chain skipping non-aliasing memory nodes, /// looking for aliasing nodes and adding them to the Aliases vector. @@ -500,15 +513,15 @@ namespace { bool isAndLoadExtLoad(ConstantSDNode *AndC, LoadSDNode *LoadN, EVT LoadResultTy, EVT &ExtVT); - /// Helper function to calculate whether the given Load can have its + /// Helper function to calculate whether the given Load/Store can have its /// width reduced to ExtVT. - bool isLegalNarrowLoad(LoadSDNode *LoadN, ISD::LoadExtType ExtType, - EVT &ExtVT, unsigned ShAmt = 0); + bool isLegalNarrowLdSt(LSBaseSDNode *LDSTN, ISD::LoadExtType ExtType, + EVT &MemVT, unsigned ShAmt = 0); /// Used by BackwardsPropagateMask to find suitable loads. bool SearchForAndLoads(SDNode *N, SmallPtrSetImpl<LoadSDNode*> &Loads, - SmallPtrSetImpl<SDNode*> &NodeWithConsts, - ConstantSDNode *Mask, SDNode *&UncombinedNode); + SmallPtrSetImpl<SDNode*> &NodesWithConsts, + ConstantSDNode *Mask, SDNode *&NodeToMask); /// Attempt to propagate a given AND node back to load leaves so that they /// can be combined into narrow loads. bool BackwardsPropagateMask(SDNode *N, SelectionDAG &DAG); @@ -530,23 +543,28 @@ namespace { /// This is a helper function for MergeConsecutiveStores. Stores /// that potentially may be merged with St are placed in - /// StoreNodes. + /// StoreNodes. RootNode is a chain predecessor to all store + /// candidates. void getStoreMergeCandidates(StoreSDNode *St, - SmallVectorImpl<MemOpLink> &StoreNodes); + SmallVectorImpl<MemOpLink> &StoreNodes, + SDNode *&Root); /// Helper function for MergeConsecutiveStores. Checks if /// candidate stores have indirect dependency through their - /// operands. \return True if safe to merge. + /// operands. RootNode is the predecessor to all stores calculated + /// by getStoreMergeCandidates and is used to prune the dependency check. + /// \return True if safe to merge. bool checkMergeStoreCandidatesForDependencies( - SmallVectorImpl<MemOpLink> &StoreNodes, unsigned NumStores); + SmallVectorImpl<MemOpLink> &StoreNodes, unsigned NumStores, + SDNode *RootNode); /// Merge consecutive store operations into a wide store. /// This optimization uses wide integers or vectors when possible. /// \return number of stores that were merged into a merged store (the /// affected nodes are stored as a prefix in \p StoreNodes). - bool MergeConsecutiveStores(StoreSDNode *N); + bool MergeConsecutiveStores(StoreSDNode *St); - /// \brief Try to transform a truncation where C is a constant: + /// Try to transform a truncation where C is a constant: /// (trunc (and X, C)) -> (and (trunc X), (trunc C)) /// /// \p N needs to be a truncation and its first operand an AND. Other @@ -554,6 +572,16 @@ namespace { /// single-use) and if missed an empty SDValue is returned. SDValue distributeTruncateThroughAnd(SDNode *N); + /// Helper function to determine whether the target supports operation + /// given by \p Opcode for type \p VT, that is, whether the operation + /// is legal or custom before legalizing operations, and whether it is + /// legal (but not custom) after legalization. 
+ bool hasOperation(unsigned Opcode, EVT VT) { + if (LegalOperations) + return TLI.isOperationLegal(Opcode, VT); + return TLI.isOperationLegalOrCustom(Opcode, VT); + } + public: /// Runs the dag combiner on all nodes in the work list void Run(CombineLevel AtLevel); @@ -564,11 +592,7 @@ namespace { /// legalization these can be huge. EVT getShiftAmountTy(EVT LHSTy) { assert(LHSTy.isInteger() && "Shift amount is not an integer type!"); - if (LHSTy.isVector()) - return LHSTy; - auto &DL = DAG.getDataLayout(); - return LegalTypes ? TLI.getScalarShiftAmountTy(DL, LHSTy) - : TLI.getPointerTy(DL); + return TLI.getShiftAmountTy(LHSTy, DAG.getDataLayout(), LegalTypes); } /// This method returns true if we are running before type legalization or @@ -582,6 +606,10 @@ namespace { EVT getSetCCResultType(EVT VT) const { return TLI.getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), VT); } + + void ExtendSetCCUses(const SmallVectorImpl<SDNode *> &SetCCs, + SDValue OrigLoad, SDValue ExtLoad, + ISD::NodeType ExtType); }; /// This class is a DAGUpdateListener that removes any deleted @@ -657,8 +685,13 @@ static char isNegatibleForFree(SDValue Op, bool LegalOperations, // fneg is removable even if it has multiple uses. if (Op.getOpcode() == ISD::FNEG) return 2; - // Don't allow anything with multiple uses. - if (!Op.hasOneUse()) return 0; + // Don't allow anything with multiple uses unless we know it is free. + EVT VT = Op.getValueType(); + const SDNodeFlags Flags = Op->getFlags(); + if (!Op.hasOneUse()) + if (!(Op.getOpcode() == ISD::FP_EXTEND && + TLI.isFPExtFree(VT, Op.getOperand(0).getValueType()))) + return 0; // Don't recurse exponentially. if (Depth > 6) return 0; @@ -671,17 +704,15 @@ static char isNegatibleForFree(SDValue Op, bool LegalOperations, // Don't invert constant FP values after legalization unless the target says // the negated constant is legal. - EVT VT = Op.getValueType(); return TLI.isOperationLegal(ISD::ConstantFP, VT) || TLI.isFPImmLegal(neg(cast<ConstantFPSDNode>(Op)->getValueAPF()), VT); } case ISD::FADD: - // FIXME: determine better conditions for this xform. - if (!Options->UnsafeFPMath) return 0; + if (!Options->UnsafeFPMath && !Flags.hasNoSignedZeros()) + return 0; // After operation legalization, it might not be legal to create new FSUBs. - if (LegalOperations && - !TLI.isOperationLegalOrCustom(ISD::FSUB, Op.getValueType())) + if (LegalOperations && !TLI.isOperationLegalOrCustom(ISD::FSUB, VT)) return 0; // fold (fneg (fadd A, B)) -> (fsub (fneg A), B) @@ -694,7 +725,7 @@ static char isNegatibleForFree(SDValue Op, bool LegalOperations, case ISD::FSUB: // We can't turn -(A-B) into B-A when we honor signed zeros. if (!Options->NoSignedZerosFPMath && - !Op.getNode()->getFlags().hasNoSignedZeros()) + !Flags.hasNoSignedZeros()) return 0; // fold (fneg (fsub A, B)) -> (fsub B, A) @@ -702,8 +733,6 @@ static char isNegatibleForFree(SDValue Op, bool LegalOperations, case ISD::FMUL: case ISD::FDIV: - if (Options->HonorSignDependentRoundingFPMath()) return 0; - // fold (fneg (fmul X, Y)) -> (fmul (fneg X), Y) or (fmul X, (fneg Y)) if (char V = isNegatibleForFree(Op.getOperand(0), LegalOperations, TLI, Options, Depth + 1)) @@ -727,9 +756,6 @@ static SDValue GetNegatedExpression(SDValue Op, SelectionDAG &DAG, // fneg is removable even if it has multiple uses. if (Op.getOpcode() == ISD::FNEG) return Op.getOperand(0); - // Don't allow anything with multiple uses. 
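The hasOperation helper above distinguishes the pre- and post-legalization notions of "supported". A standalone restatement of the predicate (toy Legality enum, not LLVM's real LegalizeAction):

  #include <cassert>

  enum class Legality { Legal, Custom, Expand };

  // Before operation legalization, Legal and Custom both count as usable;
  // afterwards only Legal does, since Custom lowering can no longer run.
  bool hasOperationSketch(Legality L, bool LegalOperations) {
    if (LegalOperations)
      return L == Legality::Legal;
    return L == Legality::Legal || L == Legality::Custom;
  }

  int main() {
    assert(hasOperationSketch(Legality::Custom, /*LegalOperations=*/false));
    assert(!hasOperationSketch(Legality::Custom, /*LegalOperations=*/true));
    return 0;
  }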
- assert(Op.hasOneUse() && "Unknown reuse!"); - assert(Depth <= 6 && "GetNegatedExpression doesn't match isNegatibleForFree"); const SDNodeFlags Flags = Op.getNode()->getFlags(); @@ -742,8 +768,7 @@ static SDValue GetNegatedExpression(SDValue Op, SelectionDAG &DAG, return DAG.getConstantFP(V, SDLoc(Op), Op.getValueType()); } case ISD::FADD: - // FIXME: determine better conditions for this xform. - assert(Options.UnsafeFPMath); + assert(Options.UnsafeFPMath || Flags.hasNoSignedZeros()); // fold (fneg (fadd A, B)) -> (fsub (fneg A), B) if (isNegatibleForFree(Op.getOperand(0), LegalOperations, @@ -769,8 +794,6 @@ static SDValue GetNegatedExpression(SDValue Op, SelectionDAG &DAG, case ISD::FMUL: case ISD::FDIV: - assert(!Options.HonorSignDependentRoundingFPMath()); - // fold (fneg (fmul X, Y)) -> (fmul (fneg X), Y) if (isNegatibleForFree(Op.getOperand(0), LegalOperations, DAG.getTargetLoweringInfo(), &Options, Depth+1)) @@ -846,7 +869,13 @@ bool DAGCombiner::isOneUseSetCC(SDValue N) const { return false; } -// \brief Returns the SDNode if it is a constant float BuildVector +static SDValue peekThroughBitcast(SDValue V) { + while (V.getOpcode() == ISD::BITCAST) + V = V.getOperand(0); + return V; +} + +// Returns the SDNode if it is a constant float BuildVector // or constant float. static SDNode *isConstantFPBuildVectorOrConstantFP(SDValue N) { if (isa<ConstantFPSDNode>(N)) @@ -880,6 +909,7 @@ static bool isConstantOrConstantVector(SDValue N, bool NoOpaques = false) { // constant null integer (with no undefs). // Build vector implicit truncation is not an issue for null values. static bool isNullConstantOrNullSplatConstant(SDValue N) { + // TODO: may want to use peekThroughBitcast() here. if (ConstantSDNode *Splat = isConstOrConstSplat(N)) return Splat->isNullValue(); return false; @@ -889,6 +919,7 @@ static bool isNullConstantOrNullSplatConstant(SDValue N) { // constant integer of one (with no undefs). // Do not permit build vector implicit truncation. static bool isOneConstantOrOneSplatConstant(SDValue N) { + // TODO: may want to use peekThroughBitcast() here. unsigned BitWidth = N.getScalarValueSizeInBits(); if (ConstantSDNode *Splat = isConstOrConstSplat(N)) return Splat->isOne() && Splat->getAPIntValue().getBitWidth() == BitWidth; @@ -899,6 +930,7 @@ static bool isOneConstantOrOneSplatConstant(SDValue N) { // constant integer of all ones (with no undefs). // Do not permit build vector implicit truncation. static bool isAllOnesConstantOrAllOnesSplatConstant(SDValue N) { + N = peekThroughBitcast(N); unsigned BitWidth = N.getScalarValueSizeInBits(); if (ConstantSDNode *Splat = isConstOrConstSplat(N)) return Splat->isAllOnesValue() && @@ -913,56 +945,6 @@ static bool isAnyConstantBuildVector(const SDNode *N) { ISD::isBuildVectorOfConstantFPSDNodes(N); } -// Attempt to match a unary predicate against a scalar/splat constant or -// every element of a constant BUILD_VECTOR. -static bool matchUnaryPredicate(SDValue Op, - std::function<bool(ConstantSDNode *)> Match) { - if (auto *Cst = dyn_cast<ConstantSDNode>(Op)) - return Match(Cst); - - if (ISD::BUILD_VECTOR != Op.getOpcode()) - return false; - - EVT SVT = Op.getValueType().getScalarType(); - for (unsigned i = 0, e = Op.getNumOperands(); i != e; ++i) { - auto *Cst = dyn_cast<ConstantSDNode>(Op.getOperand(i)); - if (!Cst || Cst->getValueType(0) != SVT || !Match(Cst)) - return false; - } - return true; -} - -// Attempt to match a binary predicate against a pair of scalar/splat constants -// or every element of a pair of constant BUILD_VECTORs. 
-static bool matchBinaryPredicate( - SDValue LHS, SDValue RHS, - std::function<bool(ConstantSDNode *, ConstantSDNode *)> Match) { - if (LHS.getValueType() != RHS.getValueType()) - return false; - - if (auto *LHSCst = dyn_cast<ConstantSDNode>(LHS)) - if (auto *RHSCst = dyn_cast<ConstantSDNode>(RHS)) - return Match(LHSCst, RHSCst); - - if (ISD::BUILD_VECTOR != LHS.getOpcode() || - ISD::BUILD_VECTOR != RHS.getOpcode()) - return false; - - EVT SVT = LHS.getValueType().getScalarType(); - for (unsigned i = 0, e = LHS.getNumOperands(); i != e; ++i) { - auto *LHSCst = dyn_cast<ConstantSDNode>(LHS.getOperand(i)); - auto *RHSCst = dyn_cast<ConstantSDNode>(RHS.getOperand(i)); - if (!LHSCst || !RHSCst) - return false; - if (LHSCst->getValueType(0) != SVT || - LHSCst->getValueType(0) != RHSCst->getValueType(0)) - return false; - if (!Match(LHSCst, RHSCst)) - return false; - } - return true; -} - SDValue DAGCombiner::ReassociateOps(unsigned Opc, const SDLoc &DL, SDValue N0, SDValue N1) { EVT VT = N0.getValueType(); @@ -1013,11 +995,9 @@ SDValue DAGCombiner::CombineTo(SDNode *N, const SDValue *To, unsigned NumTo, bool AddTo) { assert(N->getNumValues() == NumTo && "Broken CombineTo call!"); ++NodesCombined; - DEBUG(dbgs() << "\nReplacing.1 "; - N->dump(&DAG); - dbgs() << "\nWith: "; - To[0].getNode()->dump(&DAG); - dbgs() << " and " << NumTo-1 << " other values\n"); + LLVM_DEBUG(dbgs() << "\nReplacing.1 "; N->dump(&DAG); dbgs() << "\nWith: "; + To[0].getNode()->dump(&DAG); + dbgs() << " and " << NumTo - 1 << " other values\n"); for (unsigned i = 0, e = NumTo; i != e; ++i) assert((!To[i].getNode() || N->getValueType(i) == To[i].getValueType()) && @@ -1074,11 +1054,33 @@ bool DAGCombiner::SimplifyDemandedBits(SDValue Op, const APInt &Demanded) { // Replace the old value with the new one. ++NodesCombined; - DEBUG(dbgs() << "\nReplacing.2 "; - TLO.Old.getNode()->dump(&DAG); - dbgs() << "\nWith: "; - TLO.New.getNode()->dump(&DAG); - dbgs() << '\n'); + LLVM_DEBUG(dbgs() << "\nReplacing.2 "; TLO.Old.getNode()->dump(&DAG); + dbgs() << "\nWith: "; TLO.New.getNode()->dump(&DAG); + dbgs() << '\n'); + + CommitTargetLoweringOpt(TLO); + return true; +} + +/// Check the specified vector node value to see if it can be simplified or +/// if things it uses can be simplified as it only uses some of the elements. +/// If so, return true. +bool DAGCombiner::SimplifyDemandedVectorElts(SDValue Op, const APInt &Demanded, + bool AssumeSingleUse) { + TargetLowering::TargetLoweringOpt TLO(DAG, LegalTypes, LegalOperations); + APInt KnownUndef, KnownZero; + if (!TLI.SimplifyDemandedVectorElts(Op, Demanded, KnownUndef, KnownZero, TLO, + 0, AssumeSingleUse)) + return false; + + // Revisit the node. + AddToWorklist(Op.getNode()); + + // Replace the old value with the new one. 
+ ++NodesCombined; + LLVM_DEBUG(dbgs() << "\nReplacing.2 "; TLO.Old.getNode()->dump(&DAG); + dbgs() << "\nWith: "; TLO.New.getNode()->dump(&DAG); + dbgs() << '\n'); CommitTargetLoweringOpt(TLO); return true; @@ -1089,11 +1091,8 @@ void DAGCombiner::ReplaceLoadWithPromotedLoad(SDNode *Load, SDNode *ExtLoad) { EVT VT = Load->getValueType(0); SDValue Trunc = DAG.getNode(ISD::TRUNCATE, DL, VT, SDValue(ExtLoad, 0)); - DEBUG(dbgs() << "\nReplacing.9 "; - Load->dump(&DAG); - dbgs() << "\nWith: "; - Trunc.getNode()->dump(&DAG); - dbgs() << '\n'); + LLVM_DEBUG(dbgs() << "\nReplacing.9 "; Load->dump(&DAG); dbgs() << "\nWith: "; + Trunc.getNode()->dump(&DAG); dbgs() << '\n'); WorklistRemover DeadNodes(*this); DAG.ReplaceAllUsesOfValueWith(SDValue(Load, 0), Trunc); DAG.ReplaceAllUsesOfValueWith(SDValue(Load, 1), SDValue(ExtLoad, 1)); @@ -1107,10 +1106,8 @@ SDValue DAGCombiner::PromoteOperand(SDValue Op, EVT PVT, bool &Replace) { if (ISD::isUNINDEXEDLoad(Op.getNode())) { LoadSDNode *LD = cast<LoadSDNode>(Op); EVT MemVT = LD->getMemoryVT(); - ISD::LoadExtType ExtType = ISD::isNON_EXTLoad(LD) - ? (TLI.isLoadExtLegal(ISD::ZEXTLOAD, PVT, MemVT) ? ISD::ZEXTLOAD - : ISD::EXTLOAD) - : LD->getExtensionType(); + ISD::LoadExtType ExtType = ISD::isNON_EXTLoad(LD) ? ISD::EXTLOAD + : LD->getExtensionType(); Replace = true; return DAG.getExtLoad(ExtType, DL, PVT, LD->getChain(), LD->getBasePtr(), @@ -1194,7 +1191,7 @@ SDValue DAGCombiner::PromoteIntBinOp(SDValue Op) { if (TLI.IsDesirableToPromoteOp(Op, PVT)) { assert(PVT != VT && "Don't know what type to promote to!"); - DEBUG(dbgs() << "\nPromoting "; Op.getNode()->dump(&DAG)); + LLVM_DEBUG(dbgs() << "\nPromoting "; Op.getNode()->dump(&DAG)); bool Replace0 = false; SDValue N0 = Op.getOperand(0); @@ -1259,7 +1256,7 @@ SDValue DAGCombiner::PromoteIntShiftOp(SDValue Op) { if (TLI.IsDesirableToPromoteOp(Op, PVT)) { assert(PVT != VT && "Don't know what type to promote to!"); - DEBUG(dbgs() << "\nPromoting "; Op.getNode()->dump(&DAG)); + LLVM_DEBUG(dbgs() << "\nPromoting "; Op.getNode()->dump(&DAG)); bool Replace = false; SDValue N0 = Op.getOperand(0); @@ -1311,8 +1308,7 @@ SDValue DAGCombiner::PromoteExtend(SDValue Op) { // fold (aext (aext x)) -> (aext x) // fold (aext (zext x)) -> (zext x) // fold (aext (sext x)) -> (sext x) - DEBUG(dbgs() << "\nPromoting "; - Op.getNode()->dump(&DAG)); + LLVM_DEBUG(dbgs() << "\nPromoting "; Op.getNode()->dump(&DAG)); return DAG.getNode(Op.getOpcode(), SDLoc(Op), VT, Op.getOperand(0)); } return SDValue(); @@ -1345,20 +1341,15 @@ bool DAGCombiner::PromoteLoad(SDValue Op) { SDNode *N = Op.getNode(); LoadSDNode *LD = cast<LoadSDNode>(N); EVT MemVT = LD->getMemoryVT(); - ISD::LoadExtType ExtType = ISD::isNON_EXTLoad(LD) - ? (TLI.isLoadExtLegal(ISD::ZEXTLOAD, PVT, MemVT) ? ISD::ZEXTLOAD - : ISD::EXTLOAD) - : LD->getExtensionType(); + ISD::LoadExtType ExtType = ISD::isNON_EXTLoad(LD) ? 
ISD::EXTLOAD + : LD->getExtensionType(); SDValue NewLD = DAG.getExtLoad(ExtType, DL, PVT, LD->getChain(), LD->getBasePtr(), MemVT, LD->getMemOperand()); SDValue Result = DAG.getNode(ISD::TRUNCATE, DL, VT, NewLD); - DEBUG(dbgs() << "\nPromoting "; - N->dump(&DAG); - dbgs() << "\nTo: "; - Result.getNode()->dump(&DAG); - dbgs() << '\n'); + LLVM_DEBUG(dbgs() << "\nPromoting "; N->dump(&DAG); dbgs() << "\nTo: "; + Result.getNode()->dump(&DAG); dbgs() << '\n'); WorklistRemover DeadNodes(*this); DAG.ReplaceAllUsesOfValueWith(SDValue(N, 0), Result); DAG.ReplaceAllUsesOfValueWith(SDValue(N, 1), NewLD.getValue(1)); @@ -1369,7 +1360,7 @@ bool DAGCombiner::PromoteLoad(SDValue Op) { return false; } -/// \brief Recursively delete a node which has no uses and any operands for +/// Recursively delete a node which has no uses and any operands for /// which it is the only use. /// /// Note that this both deletes the nodes and removes them from the worklist. @@ -1453,7 +1444,7 @@ void DAGCombiner::Run(CombineLevel AtLevel) { continue; } - DEBUG(dbgs() << "\nCombining: "; N->dump(&DAG)); + LLVM_DEBUG(dbgs() << "\nCombining: "; N->dump(&DAG)); // Add any operands of the new node which have not yet been combined to the // worklist as well. Because the worklist uniques things already, this @@ -1481,8 +1472,7 @@ void DAGCombiner::Run(CombineLevel AtLevel) { RV.getOpcode() != ISD::DELETED_NODE && "Node was deleted but visit returned new node!"); - DEBUG(dbgs() << " ... into: "; - RV.getNode()->dump(&DAG)); + LLVM_DEBUG(dbgs() << " ... into: "; RV.getNode()->dump(&DAG)); if (N->getNumValues() == RV.getNode()->getNumValues()) DAG.ReplaceAllUsesWith(N, RV.getNode()); @@ -1558,7 +1548,6 @@ SDValue DAGCombiner::visit(SDNode *N) { case ISD::VSELECT: return visitVSELECT(N); case ISD::SELECT_CC: return visitSELECT_CC(N); case ISD::SETCC: return visitSETCC(N); - case ISD::SETCCE: return visitSETCCE(N); case ISD::SETCCCARRY: return visitSETCCCARRY(N); case ISD::SIGN_EXTEND: return visitSIGN_EXTEND(N); case ISD::ZERO_EXTEND: return visitZERO_EXTEND(N); @@ -1708,6 +1697,10 @@ SDValue DAGCombiner::visitTokenFactor(SDNode *N) { return N->getOperand(1); } + // Don't simplify token factors if optnone. + if (OptLevel == CodeGenOpt::None) + return SDValue(); + SmallVector<SDNode *, 8> TFs; // List of token factors to visit. SmallVector<SDValue, 8> Ops; // Ops for replacing token factor. SmallPtrSet<SDNode*, 16> SeenOps; @@ -1893,16 +1886,16 @@ SDValue DAGCombiner::foldBinOpIntoSelect(SDNode *BO) { BinOpcode == ISD::FDIV || BinOpcode == ISD::FREM) && "Unexpected binary operator"); - // Bail out if any constants are opaque because we can't constant fold those. - SDValue C1 = BO->getOperand(1); - if (!isConstantOrConstantVector(C1, true) && - !isConstantFPBuildVectorOrConstantFP(C1)) - return SDValue(); - // Don't do this unless the old select is going away. We want to eliminate the // binary operator, not replace a binop with a select. // TODO: Handle ISD::SELECT_CC. + unsigned SelOpNo = 0; SDValue Sel = BO->getOperand(0); + if (Sel.getOpcode() != ISD::SELECT || !Sel.hasOneUse()) { + SelOpNo = 1; + Sel = BO->getOperand(1); + } + if (Sel.getOpcode() != ISD::SELECT || !Sel.hasOneUse()) return SDValue(); @@ -1916,19 +1909,48 @@ SDValue DAGCombiner::foldBinOpIntoSelect(SDNode *BO) { !isConstantFPBuildVectorOrConstantFP(CF)) return SDValue(); + // Bail out if any constants are opaque because we can't constant fold those. 
+ // The exception is "and" and "or" with either 0 or -1 in which case we can + // propagate non-constant operands into select. I.e.: + // and (select Cond, 0, -1), X --> select Cond, 0, X + // or X, (select Cond, -1, 0) --> select Cond, -1, X + bool CanFoldNonConst = (BinOpcode == ISD::AND || BinOpcode == ISD::OR) && + (isNullConstantOrNullSplatConstant(CT) || + isAllOnesConstantOrAllOnesSplatConstant(CT)) && + (isNullConstantOrNullSplatConstant(CF) || + isAllOnesConstantOrAllOnesSplatConstant(CF)); + + SDValue CBO = BO->getOperand(SelOpNo ^ 1); + if (!CanFoldNonConst && + !isConstantOrConstantVector(CBO, true) && + !isConstantFPBuildVectorOrConstantFP(CBO)) + return SDValue(); + + EVT VT = Sel.getValueType(); + + // In case of a shift, the value and the shift amount may have different VTs. + // For instance, on x86 the shift amount is i8 regardless of the LHS type. + // Bail out if we have swapped operands and the value types do not match. NB: + // x86 is fine if the operands are not swapped and the shift amount VT is not + // bigger than the shifted value's. + // TODO: it is possible to check for a shift operation, correct the VTs and + // still perform the optimization on x86 if needed. + if (SelOpNo && VT != CBO.getValueType()) + return SDValue(); + // We have a select-of-constants followed by a binary operator with a // constant. Eliminate the binop by pulling the constant math into the select. - // Example: add (select Cond, CT, CF), C1 --> select Cond, CT + C1, CF + C1 - EVT VT = Sel.getValueType(); + // Example: add (select Cond, CT, CF), CBO --> select Cond, CT + CBO, CF + CBO SDLoc DL(Sel); - SDValue NewCT = DAG.getNode(BinOpcode, DL, VT, CT, C1); - if (!NewCT.isUndef() && + SDValue NewCT = SelOpNo ? DAG.getNode(BinOpcode, DL, VT, CBO, CT) + : DAG.getNode(BinOpcode, DL, VT, CT, CBO); + if (!CanFoldNonConst && !NewCT.isUndef() && !isConstantOrConstantVector(NewCT, true) && !isConstantFPBuildVectorOrConstantFP(NewCT)) return SDValue(); - SDValue NewCF = DAG.getNode(BinOpcode, DL, VT, CF, C1); - if (!NewCF.isUndef() && + SDValue NewCF = SelOpNo ? DAG.getNode(BinOpcode, DL, VT, CBO, CF) + : DAG.getNode(BinOpcode, DL, VT, CF, CBO); + if (!CanFoldNonConst && !NewCF.isUndef() && !isConstantOrConstantVector(NewCF, true) && !isConstantFPBuildVectorOrConstantFP(NewCF)) return SDValue(); @@ -1936,6 +1958,84 @@ SDValue DAGCombiner::foldBinOpIntoSelect(SDNode *BO) { return DAG.getSelect(DL, VT, Sel.getOperand(0), NewCT, NewCF); } +static SDValue foldAddSubBoolOfMaskedVal(SDNode *N, SelectionDAG &DAG) { + assert((N->getOpcode() == ISD::ADD || N->getOpcode() == ISD::SUB) && + "Expecting add or sub"); + + // Match a constant operand and a zext operand for the math instruction: + // add Z, C + // sub C, Z + bool IsAdd = N->getOpcode() == ISD::ADD; + SDValue C = IsAdd ? N->getOperand(1) : N->getOperand(0); + SDValue Z = IsAdd ? N->getOperand(0) : N->getOperand(1); + auto *CN = dyn_cast<ConstantSDNode>(C); + if (!CN || Z.getOpcode() != ISD::ZERO_EXTEND) + return SDValue(); + + // Match the zext operand as a setcc of a boolean. + if (Z.getOperand(0).getOpcode() != ISD::SETCC || + Z.getOperand(0).getValueType() != MVT::i1) + return SDValue(); + + // Match the compare as: setcc (X & 1), 0, eq. 
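The CanFoldNonConst case above works because AND/OR absorb 0 and -1 regardless of the other operand. A standalone check of the two folds quoted in the comment:

  #include <cassert>

  int main() {
    for (int X : {0, 1, -7, 42})
      for (bool Cond : {false, true}) {
        // and (select Cond, 0, -1), X --> select Cond, 0, X
        assert(((Cond ? 0 : -1) & X) == (Cond ? 0 : X));
        // or X, (select Cond, -1, 0) --> select Cond, -1, X
        assert((X | (Cond ? -1 : 0)) == (Cond ? -1 : X));
      }
    return 0;
  }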
+ SDValue SetCC = Z.getOperand(0); + ISD::CondCode CC = cast<CondCodeSDNode>(SetCC->getOperand(2))->get(); + if (CC != ISD::SETEQ || !isNullConstant(SetCC.getOperand(1)) || + SetCC.getOperand(0).getOpcode() != ISD::AND || + !isOneConstant(SetCC.getOperand(0).getOperand(1))) + return SDValue(); + + // We are adding/subtracting a constant and an inverted low bit. Turn that + // into a subtract/add of the low bit with incremented/decremented constant: + // add (zext i1 (seteq (X & 1), 0)), C --> sub C+1, (zext (X & 1)) + // sub C, (zext i1 (seteq (X & 1), 0)) --> add C-1, (zext (X & 1)) + EVT VT = C.getValueType(); + SDLoc DL(N); + SDValue LowBit = DAG.getZExtOrTrunc(SetCC.getOperand(0), DL, VT); + SDValue C1 = IsAdd ? DAG.getConstant(CN->getAPIntValue() + 1, DL, VT) : + DAG.getConstant(CN->getAPIntValue() - 1, DL, VT); + return DAG.getNode(IsAdd ? ISD::SUB : ISD::ADD, DL, VT, C1, LowBit); +} + +/// Try to fold a 'not' shifted sign-bit with add/sub with constant operand into +/// a shift and add with a different constant. +static SDValue foldAddSubOfSignBit(SDNode *N, SelectionDAG &DAG) { + assert((N->getOpcode() == ISD::ADD || N->getOpcode() == ISD::SUB) && + "Expecting add or sub"); + + // We need a constant operand for the add/sub, and the other operand is a + // logical shift right: add (srl), C or sub C, (srl). + bool IsAdd = N->getOpcode() == ISD::ADD; + SDValue ConstantOp = IsAdd ? N->getOperand(1) : N->getOperand(0); + SDValue ShiftOp = IsAdd ? N->getOperand(0) : N->getOperand(1); + ConstantSDNode *C = isConstOrConstSplat(ConstantOp); + if (!C || ShiftOp.getOpcode() != ISD::SRL) + return SDValue(); + + // The shift must be of a 'not' value. + // TODO: Use isBitwiseNot() if it works with vectors. + SDValue Not = ShiftOp.getOperand(0); + if (!Not.hasOneUse() || Not.getOpcode() != ISD::XOR || + !isAllOnesConstantOrAllOnesSplatConstant(Not.getOperand(1))) + return SDValue(); + + // The shift must be moving the sign bit to the least-significant-bit. + EVT VT = ShiftOp.getValueType(); + SDValue ShAmt = ShiftOp.getOperand(1); + ConstantSDNode *ShAmtC = isConstOrConstSplat(ShAmt); + if (!ShAmtC || ShAmtC->getZExtValue() != VT.getScalarSizeInBits() - 1) + return SDValue(); + + // Eliminate the 'not' by adjusting the shift and add/sub constant: + // add (srl (not X), 31), C --> add (sra X, 31), (C + 1) + // sub C, (srl (not X), 31) --> add (srl X, 31), (C - 1) + SDLoc DL(N); + auto ShOpcode = IsAdd ? ISD::SRA : ISD::SRL; + SDValue NewShift = DAG.getNode(ShOpcode, DL, VT, Not.getOperand(0), ShAmt); + APInt NewC = IsAdd ? 
C->getAPIntValue() + 1 : C->getAPIntValue() - 1; + return DAG.getNode(ISD::ADD, DL, VT, NewShift, DAG.getConstant(NewC, DL, VT)); +} + SDValue DAGCombiner::visitADD(SDNode *N) { SDValue N0 = N->getOperand(0); SDValue N1 = N->getOperand(1); @@ -2067,6 +2167,12 @@ SDValue DAGCombiner::visitADD(SDNode *N) { DAG.getNode(ISD::ADD, SDLoc(N1), VT, N01, N11)); } + if (SDValue V = foldAddSubBoolOfMaskedVal(N, DAG)) + return V; + + if (SDValue V = foldAddSubOfSignBit(N, DAG)) + return V; + if (SimplifyDemandedBits(SDValue(N, 0))) return SDValue(N, 0); @@ -2075,6 +2181,11 @@ SDValue DAGCombiner::visitADD(SDNode *N) { DAG.haveNoCommonBitsSet(N0, N1)) return DAG.getNode(ISD::OR, DL, VT, N0, N1); + // fold (add (xor a, -1), 1) -> (sub 0, a) + if (isBitwiseNot(N0) && isOneConstantOrOneSplatConstant(N1)) + return DAG.getNode(ISD::SUB, DL, VT, DAG.getConstant(0, DL, VT), + N0.getOperand(0)); + if (SDValue Combined = visitADDLike(N0, N1, N)) return Combined; @@ -2210,6 +2321,38 @@ SDValue DAGCombiner::visitADDC(SDNode *N) { return SDValue(); } +static SDValue flipBoolean(SDValue V, const SDLoc &DL, EVT VT, + SelectionDAG &DAG, const TargetLowering &TLI) { + SDValue Cst; + switch (TLI.getBooleanContents(VT)) { + case TargetLowering::ZeroOrOneBooleanContent: + case TargetLowering::UndefinedBooleanContent: + Cst = DAG.getConstant(1, DL, VT); + break; + case TargetLowering::ZeroOrNegativeOneBooleanContent: + Cst = DAG.getConstant(-1, DL, VT); + break; + } + + return DAG.getNode(ISD::XOR, DL, VT, V, Cst); +} + +static bool isBooleanFlip(SDValue V, EVT VT, const TargetLowering &TLI) { + if (V.getOpcode() != ISD::XOR) return false; + ConstantSDNode *Const = dyn_cast<ConstantSDNode>(V.getOperand(1)); + if (!Const) return false; + + switch(TLI.getBooleanContents(VT)) { + case TargetLowering::ZeroOrOneBooleanContent: + return Const->isOne(); + case TargetLowering::ZeroOrNegativeOneBooleanContent: + return Const->isAllOnesValue(); + case TargetLowering::UndefinedBooleanContent: + return (Const->getAPIntValue() & 0x01) == 1; + } + llvm_unreachable("Unsupported boolean content"); +} + SDValue DAGCombiner::visitUADDO(SDNode *N) { SDValue N0 = N->getOperand(0); SDValue N1 = N->getOperand(1); @@ -2240,6 +2383,15 @@ SDValue DAGCombiner::visitUADDO(SDNode *N) { return CombineTo(N, DAG.getNode(ISD::ADD, DL, VT, N0, N1), DAG.getConstant(0, DL, CarryVT)); + // fold (uaddo (xor a, -1), 1) -> (usub 0, a) and flip carry. + if (isBitwiseNot(N0) && isOneConstantOrOneSplatConstant(N1)) { + SDValue Sub = DAG.getNode(ISD::USUBO, DL, N->getVTList(), + DAG.getConstant(0, DL, VT), + N0.getOperand(0)); + return CombineTo(N, Sub, + flipBoolean(Sub.getValue(1), DL, CarryVT, DAG, TLI)); + } + if (SDValue Combined = visitUADDOLike(N0, N1, N)) return Combined; @@ -2303,13 +2455,17 @@ SDValue DAGCombiner::visitADDCARRY(SDNode *N) { return DAG.getNode(ISD::ADDCARRY, DL, N->getVTList(), N1, N0, CarryIn); // fold (addcarry x, y, false) -> (uaddo x, y) - if (isNullConstant(CarryIn)) - return DAG.getNode(ISD::UADDO, DL, N->getVTList(), N0, N1); + if (isNullConstant(CarryIn)) { + if (!LegalOperations || + TLI.isOperationLegalOrCustom(ISD::UADDO, N->getValueType(0))) + return DAG.getNode(ISD::UADDO, DL, N->getVTList(), N0, N1); + } + + EVT CarryVT = CarryIn.getValueType(); // fold (addcarry 0, 0, X) -> (and (ext/trunc X), 1) and no carry. 
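foldAddSubBoolOfMaskedVal and foldAddSubOfSignBit, both completed above, rest on small two's-complement identities. A standalone spot check (32-bit; assumes arithmetic >> on signed values, which C++20 guarantees; constants kept small so C+1 cannot overflow, while the DAG versions wrap modulo 2^32):

  #include <cassert>
  #include <cstdint>

  int main() {
    const int32_t C = 100;
    for (int32_t X : {0, 1, 5, -3, INT32_MIN, INT32_MAX}) {
      // add (zext i1 ((X & 1) == 0)), C --> sub C+1, (zext (X & 1))
      assert(((X & 1) == 0 ? 1 : 0) + C == (C + 1) - (X & 1));
      // add (srl (not X), 31), C --> add (sra X, 31), (C + 1)
      int32_t NotSrl = (int32_t)((uint32_t)~X >> 31); // logical shift of ~X
      int32_t Sra = X >> 31;                          // arithmetic shift of X
      assert(NotSrl + C == Sra + (C + 1));
    }
    return 0;
  }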
if (isNullConstant(N0) && isNullConstant(N1)) { EVT VT = N0.getValueType(); - EVT CarryVT = CarryIn.getValueType(); SDValue CarryExt = DAG.getBoolExtOrTrunc(CarryIn, DL, VT, CarryVT); AddToWorklist(CarryExt.getNode()); return CombineTo(N, DAG.getNode(ISD::AND, DL, VT, CarryExt, @@ -2317,6 +2473,16 @@ SDValue DAGCombiner::visitADDCARRY(SDNode *N) { DAG.getConstant(0, DL, CarryVT)); } + // fold (addcarry (xor a, -1), 0, !b) -> (subcarry 0, a, b) and flip carry. + if (isBitwiseNot(N0) && isNullConstant(N1) && + isBooleanFlip(CarryIn, CarryVT, TLI)) { + SDValue Sub = DAG.getNode(ISD::SUBCARRY, DL, N->getVTList(), + DAG.getConstant(0, DL, N0.getValueType()), + N0.getOperand(0), CarryIn.getOperand(0)); + return CombineTo(N, Sub, + flipBoolean(Sub.getValue(1), DL, CarryVT, DAG, TLI)); + } + if (SDValue Combined = visitADDCARRYLike(N0, N1, CarryIn, N)) return Combined; @@ -2458,6 +2624,11 @@ SDValue DAGCombiner::visitSUB(SDNode *N) { if (isAllOnesConstantOrAllOnesSplatConstant(N0)) return DAG.getNode(ISD::XOR, DL, VT, N1, N0); + // fold (A - (0-B)) -> A+B + if (N1.getOpcode() == ISD::SUB && + isNullConstantOrNullSplatConstant(N1.getOperand(0))) + return DAG.getNode(ISD::ADD, DL, VT, N0, N1.getOperand(1)); + // fold A-(A-B) -> B if (N1.getOpcode() == ISD::SUB && N0 == N1.getOperand(0)) return N1.getOperand(1); @@ -2500,12 +2671,50 @@ SDValue DAGCombiner::visitSUB(SDNode *N) { return DAG.getNode(ISD::SUB, DL, VT, N0.getOperand(0), N0.getOperand(1).getOperand(0)); + // fold (X - (-Y * Z)) -> (X + (Y * Z)) + if (N1.getOpcode() == ISD::MUL && N1.hasOneUse()) { + if (N1.getOperand(0).getOpcode() == ISD::SUB && + isNullConstantOrNullSplatConstant(N1.getOperand(0).getOperand(0))) { + SDValue Mul = DAG.getNode(ISD::MUL, DL, VT, + N1.getOperand(0).getOperand(1), + N1.getOperand(1)); + return DAG.getNode(ISD::ADD, DL, VT, N0, Mul); + } + if (N1.getOperand(1).getOpcode() == ISD::SUB && + isNullConstantOrNullSplatConstant(N1.getOperand(1).getOperand(0))) { + SDValue Mul = DAG.getNode(ISD::MUL, DL, VT, + N1.getOperand(0), + N1.getOperand(1).getOperand(1)); + return DAG.getNode(ISD::ADD, DL, VT, N0, Mul); + } + } + // If either operand of a sub is undef, the result is undef if (N0.isUndef()) return N0; if (N1.isUndef()) return N1; + if (SDValue V = foldAddSubBoolOfMaskedVal(N, DAG)) + return V; + + if (SDValue V = foldAddSubOfSignBit(N, DAG)) + return V; + + // fold Y = sra (X, size(X)-1); sub (xor (X, Y), Y) -> (abs X) + if (TLI.isOperationLegalOrCustom(ISD::ABS, VT)) { + if (N0.getOpcode() == ISD::XOR && N1.getOpcode() == ISD::SRA) { + SDValue X0 = N0.getOperand(0), X1 = N0.getOperand(1); + SDValue S0 = N1.getOperand(0); + if ((X0 == S0 && X1 == N1) || (X0 == N1 && X1 == S0)) { + unsigned OpSizeInBits = VT.getScalarSizeInBits(); + if (ConstantSDNode *C = isConstOrConstSplat(N1.getOperand(1))) + if (C->getAPIntValue() == (OpSizeInBits - 1)) + return DAG.getNode(ISD::ABS, SDLoc(N), VT, S0); + } + } + } + // If the relocation model supports it, consider symbol offsets. 
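The new visitSUB folds above are likewise plain integer identities; a standalone check of the negated-operand rewrites and the abs pattern (32-bit; INT32_MIN deliberately excluded since abs overflows there):

  #include <cassert>
  #include <cstdint>
  #include <cstdlib>

  int main() {
    for (int32_t A : {7, -7, 0, 123456}) {
      for (int32_t B : {3, -5, 1}) {
        assert(A - (0 - B) == A + B);           // fold (A - (0-B)) -> A+B
        assert(A - ((0 - B) * 2) == A + B * 2); // fold (X - (-Y * Z)) -> X + (Y * Z)
      }
      // Y = sra(A, 31); sub (xor (A, Y), Y) -> abs(A)
      int32_t Y = A >> 31;
      assert(((A ^ Y) - Y) == std::abs(A));
    }
    return 0;
  }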
if (GlobalAddressSDNode *GA = dyn_cast<GlobalAddressSDNode>(N0)) if (!LegalOperations && TLI.isOffsetFoldingLegal(GA)) { @@ -2612,8 +2821,11 @@ SDValue DAGCombiner::visitSUBCARRY(SDNode *N) { SDValue CarryIn = N->getOperand(2); // fold (subcarry x, y, false) -> (usubo x, y) - if (isNullConstant(CarryIn)) - return DAG.getNode(ISD::USUBO, SDLoc(N), N->getVTList(), N0, N1); + if (isNullConstant(CarryIn)) { + if (!LegalOperations || + TLI.isOperationLegalOrCustom(ISD::USUBO, N->getValueType(0))) + return DAG.getNode(ISD::USUBO, SDLoc(N), N->getVTList(), N0, N1); + } return SDValue(); } @@ -2689,11 +2901,8 @@ SDValue DAGCombiner::visitMUL(SDNode *N) { (!VT.isVector() || Level <= AfterLegalizeVectorOps)) { SDLoc DL(N); SDValue LogBase2 = BuildLogBase2(N1, DL); - AddToWorklist(LogBase2.getNode()); - EVT ShiftVT = getShiftAmountTy(N0.getValueType()); SDValue Trunc = DAG.getZExtOrTrunc(LogBase2, DL, ShiftVT); - AddToWorklist(Trunc.getNode()); return DAG.getNode(ISD::SHL, DL, VT, N0, Trunc); } // fold (mul x, -(1 << c)) -> -(x << c) or (-x) << c @@ -2816,9 +3025,10 @@ SDValue DAGCombiner::useDivRem(SDNode *Node) { SDValue Op1 = Node->getOperand(1); SDValue combined; for (SDNode::use_iterator UI = Op0.getNode()->use_begin(), - UE = Op0.getNode()->use_end(); UI != UE;) { - SDNode *User = *UI++; - if (User == Node || User->use_empty()) + UE = Op0.getNode()->use_end(); UI != UE; ++UI) { + SDNode *User = *UI; + if (User == Node || User->getOpcode() == ISD::DELETED_NODE || + User->use_empty()) continue; // Convert the other matching node(s), too; // otherwise, the DIVREM may get target-legalized into something @@ -2868,6 +3078,7 @@ SDValue DAGCombiner::visitSDIV(SDNode *N) { SDValue N0 = N->getOperand(0); SDValue N1 = N->getOperand(1); EVT VT = N->getValueType(0); + EVT CCVT = getSetCCResultType(VT); // fold vector ops if (VT.isVector()) @@ -2887,6 +3098,11 @@ SDValue DAGCombiner::visitSDIV(SDNode *N) { // fold (sdiv X, -1) -> 0-X if (N1C && N1C->isAllOnesValue()) return DAG.getNode(ISD::SUB, DL, VT, DAG.getConstant(0, DL, VT), N0); + // fold (sdiv X, MIN_SIGNED) -> select(X == MIN_SIGNED, 1, 0) + if (N1C && N1C->getAPIntValue().isMinSignedValue()) + return DAG.getSelect(DL, VT, DAG.getSetCC(DL, CCVT, N0, N1, ISD::SETEQ), + DAG.getConstant(1, DL, VT), + DAG.getConstant(0, DL, VT)); if (SDValue V = simplifyDivRem(N, DAG)) return V; @@ -2899,45 +3115,90 @@ SDValue DAGCombiner::visitSDIV(SDNode *N) { if (DAG.SignBitIsZero(N1) && DAG.SignBitIsZero(N0)) return DAG.getNode(ISD::UDIV, DL, N1.getValueType(), N0, N1); + if (SDValue V = visitSDIVLike(N0, N1, N)) + return V; + + // sdiv, srem -> sdivrem + // If the divisor is constant, then return DIVREM only if isIntDivCheap() is + // true. Otherwise, we break the simplification logic in visitREM(). + AttributeList Attr = DAG.getMachineFunction().getFunction().getAttributes(); + if (!N1C || TLI.isIntDivCheap(N->getValueType(0), Attr)) + if (SDValue DivRem = useDivRem(N)) + return DivRem; + + return SDValue(); +} + +SDValue DAGCombiner::visitSDIVLike(SDValue N0, SDValue N1, SDNode *N) { + SDLoc DL(N); + EVT VT = N->getValueType(0); + EVT CCVT = getSetCCResultType(VT); + unsigned BitWidth = VT.getScalarSizeInBits(); + + ConstantSDNode *N1C = isConstOrConstSplat(N1); + + // Helper for determining whether a value is a power-2 constant scalar or a + // vector of such elements. 
+ auto IsPowerOfTwo = [](ConstantSDNode *C) { + if (C->isNullValue() || C->isOpaque()) + return false; + if (C->getAPIntValue().isPowerOf2()) + return true; + if ((-C->getAPIntValue()).isPowerOf2()) + return true; + return false; + }; + // fold (sdiv X, pow2) -> simple ops after legalize // FIXME: We check for the exact bit here because the generic lowering gives // better results in that case. The target-specific lowering should learn how // to handle exact sdivs efficiently. - if (N1C && !N1C->isNullValue() && !N1C->isOpaque() && - !N->getFlags().hasExact() && (N1C->getAPIntValue().isPowerOf2() || - (-N1C->getAPIntValue()).isPowerOf2())) { + if (!N->getFlags().hasExact() && + ISD::matchUnaryPredicate(N1C ? SDValue(N1C, 0) : N1, IsPowerOfTwo)) { // Target-specific implementation of sdiv x, pow2. if (SDValue Res = BuildSDIVPow2(N)) return Res; - unsigned lg2 = N1C->getAPIntValue().countTrailingZeros(); + // Create constants that are functions of the shift amount value. + EVT ShiftAmtTy = getShiftAmountTy(N0.getValueType()); + SDValue Bits = DAG.getConstant(BitWidth, DL, ShiftAmtTy); + SDValue C1 = DAG.getNode(ISD::CTTZ, DL, VT, N1); + C1 = DAG.getZExtOrTrunc(C1, DL, ShiftAmtTy); + SDValue Inexact = DAG.getNode(ISD::SUB, DL, ShiftAmtTy, Bits, C1); + if (!isConstantOrConstantVector(Inexact)) + return SDValue(); // Splat the sign bit into the register - SDValue SGN = - DAG.getNode(ISD::SRA, DL, VT, N0, - DAG.getConstant(VT.getScalarSizeInBits() - 1, DL, - getShiftAmountTy(N0.getValueType()))); - AddToWorklist(SGN.getNode()); + SDValue Sign = DAG.getNode(ISD::SRA, DL, VT, N0, + DAG.getConstant(BitWidth - 1, DL, ShiftAmtTy)); + AddToWorklist(Sign.getNode()); // Add (N0 < 0) ? abs2 - 1 : 0; - SDValue SRL = - DAG.getNode(ISD::SRL, DL, VT, SGN, - DAG.getConstant(VT.getScalarSizeInBits() - lg2, DL, - getShiftAmountTy(SGN.getValueType()))); - SDValue ADD = DAG.getNode(ISD::ADD, DL, VT, N0, SRL); - AddToWorklist(SRL.getNode()); - AddToWorklist(ADD.getNode()); // Divide by pow2 - SDValue SRA = DAG.getNode(ISD::SRA, DL, VT, ADD, - DAG.getConstant(lg2, DL, - getShiftAmountTy(ADD.getValueType()))); - - // If we're dividing by a positive value, we're done. Otherwise, we must - // negate the result. - if (N1C->getAPIntValue().isNonNegative()) - return SRA; - - AddToWorklist(SRA.getNode()); - return DAG.getNode(ISD::SUB, DL, VT, DAG.getConstant(0, DL, VT), SRA); + SDValue Srl = DAG.getNode(ISD::SRL, DL, VT, Sign, Inexact); + AddToWorklist(Srl.getNode()); + SDValue Add = DAG.getNode(ISD::ADD, DL, VT, N0, Srl); + AddToWorklist(Add.getNode()); + SDValue Sra = DAG.getNode(ISD::SRA, DL, VT, Add, C1); + AddToWorklist(Sra.getNode()); + + // Special case: (sdiv X, 1) -> X + // Special Case: (sdiv X, -1) -> 0-X + SDValue One = DAG.getConstant(1, DL, VT); + SDValue AllOnes = DAG.getAllOnesConstant(DL, VT); + SDValue IsOne = DAG.getSetCC(DL, CCVT, N1, One, ISD::SETEQ); + SDValue IsAllOnes = DAG.getSetCC(DL, CCVT, N1, AllOnes, ISD::SETEQ); + SDValue IsOneOrAllOnes = DAG.getNode(ISD::OR, DL, CCVT, IsOne, IsAllOnes); + Sra = DAG.getSelect(DL, VT, IsOneOrAllOnes, N0, Sra); + + // If dividing by a positive value, we're done. Otherwise, the result must + // be negated. + SDValue Zero = DAG.getConstant(0, DL, VT); + SDValue Sub = DAG.getNode(ISD::SUB, DL, VT, Zero, Sra); + + // FIXME: Use SELECT_CC once we improve SELECT_CC constant-folding. 
+ SDValue IsNeg = DAG.getSetCC(DL, CCVT, N1, Zero, ISD::SETLT); + SDValue Res = DAG.getSelect(DL, VT, IsNeg, Sub, Sra); + return Res; } // If integer divide is expensive and we satisfy the requirements, emit an @@ -2948,13 +3209,6 @@ SDValue DAGCombiner::visitSDIV(SDNode *N) { if (SDValue Op = BuildSDIV(N)) return Op; - // sdiv, srem -> sdivrem - // If the divisor is constant, then return DIVREM only if isIntDivCheap() is - // true. Otherwise, we break the simplification logic in visitREM(). - if (!N1C || TLI.isIntDivCheap(N->getValueType(0), Attr)) - if (SDValue DivRem = useDivRem(N)) - return DivRem; - return SDValue(); } @@ -2962,6 +3216,7 @@ SDValue DAGCombiner::visitUDIV(SDNode *N) { SDValue N0 = N->getOperand(0); SDValue N1 = N->getOperand(1); EVT VT = N->getValueType(0); + EVT CCVT = getSetCCResultType(VT); // fold vector ops if (VT.isVector()) @@ -2977,6 +3232,14 @@ SDValue DAGCombiner::visitUDIV(SDNode *N) { if (SDValue Folded = DAG.FoldConstantArithmetic(ISD::UDIV, DL, VT, N0C, N1C)) return Folded; + // fold (udiv X, 1) -> X + if (N1C && N1C->isOne()) + return N0; + // fold (udiv X, -1) -> select(X == -1, 1, 0) + if (N1C && N1C->getAPIntValue().isAllOnesValue()) + return DAG.getSelect(DL, VT, DAG.getSetCC(DL, CCVT, N0, N1, ISD::SETEQ), + DAG.getConstant(1, DL, VT), + DAG.getConstant(0, DL, VT)); if (SDValue V = simplifyDivRem(N, DAG)) return V; @@ -2984,6 +3247,26 @@ SDValue DAGCombiner::visitUDIV(SDNode *N) { if (SDValue NewSel = foldBinOpIntoSelect(N)) return NewSel; + if (SDValue V = visitUDIVLike(N0, N1, N)) + return V; + + // udiv, urem -> udivrem + // If the divisor is constant, then return DIVREM only if isIntDivCheap() is + // true. Otherwise, we break the simplification logic in visitREM(). + AttributeList Attr = DAG.getMachineFunction().getFunction().getAttributes(); + if (!N1C || TLI.isIntDivCheap(N->getValueType(0), Attr)) + if (SDValue DivRem = useDivRem(N)) + return DivRem; + + return SDValue(); +} + +SDValue DAGCombiner::visitUDIVLike(SDValue N0, SDValue N1, SDNode *N) { + SDLoc DL(N); + EVT VT = N->getValueType(0); + + ConstantSDNode *N1C = isConstOrConstSplat(N1); + // fold (udiv x, (1 << c)) -> x >>u c if (isConstantOrConstantVector(N1, /*NoOpaques*/ true) && DAG.isKnownToBeAPowerOfTwo(N1)) { @@ -3019,13 +3302,6 @@ SDValue DAGCombiner::visitUDIV(SDNode *N) { if (SDValue Op = BuildUDIV(N)) return Op; - // sdiv, srem -> sdivrem - // If the divisor is constant, then return DIVREM only if isIntDivCheap() is - // true. Otherwise, we break the simplification logic in visitREM(). - if (!N1C || TLI.isIntDivCheap(N->getValueType(0), Attr)) - if (SDValue DivRem = useDivRem(N)) - return DivRem; - return SDValue(); } @@ -3035,6 +3311,8 @@ SDValue DAGCombiner::visitREM(SDNode *N) { SDValue N0 = N->getOperand(0); SDValue N1 = N->getOperand(1); EVT VT = N->getValueType(0); + EVT CCVT = getSetCCResultType(VT); + bool isSigned = (Opcode == ISD::SREM); SDLoc DL(N); @@ -3044,6 +3322,10 @@ SDValue DAGCombiner::visitREM(SDNode *N) { if (N0C && N1C) if (SDValue Folded = DAG.FoldConstantArithmetic(Opcode, DL, VT, N0C, N1C)) return Folded; + // fold (urem X, -1) -> select(X == -1, 0, X) + if (!isSigned && N1C && N1C->getAPIntValue().isAllOnesValue()) + return DAG.getSelect(DL, VT, DAG.getSetCC(DL, CCVT, N0, N1, ISD::SETEQ), + DAG.getConstant(0, DL, VT), N0); if (SDValue V = simplifyDivRem(N, DAG)) return V; @@ -3077,22 +3359,19 @@ SDValue DAGCombiner::visitREM(SDNode *N) { // If X/C can be simplified by the division-by-constant logic, lower // X%C to the equivalent of X-X/C*C. 
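A standalone model of the sign/srl/add/sra sequence that visitSDIVLike builds above, checked against C++'s truncating division (positive power-of-two divisors only; the patch additionally selects for divisors of 1 and -1 and negates the result for negative divisors; hypothetical helper name):

  #include <cassert>
  #include <cstdint>

  // sdiv X by (1 << lg2) without a divide instruction.
  int32_t sdivPow2(int32_t X, unsigned lg2) {
    int32_t Sign = X >> 31;                                // splat the sign bit
    int32_t Srl = (int32_t)((uint32_t)Sign >> (32 - lg2)); // (1<<lg2)-1 if X < 0, else 0
    int32_t Add = X + Srl;                                 // bias negative X toward zero
    return Add >> lg2;                                     // divide by the power of two
  }

  int main() {
    for (int32_t X : {0, 1, -1, 7, -7, 100, -100, 12345, -12345})
      for (unsigned lg2 : {1u, 2u, 5u})
        assert(sdivPow2(X, lg2) == X / (1 << lg2));
    return 0;
  }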
- // To avoid mangling nodes, this simplification requires that the combine() - // call for the speculative DIV must not cause a DIVREM conversion. We guard - // against this by skipping the simplification if isIntDivCheap(). When - // div is not cheap, combine will not return a DIVREM. Regardless, - // checking cheapness here makes sense since the simplification results in - // fatter code. - if (N1C && !N1C->isNullValue() && !TLI.isIntDivCheap(VT, Attr)) { - unsigned DivOpcode = isSigned ? ISD::SDIV : ISD::UDIV; - SDValue Div = DAG.getNode(DivOpcode, DL, VT, N0, N1); - AddToWorklist(Div.getNode()); - SDValue OptimizedDiv = combine(Div.getNode()); - if (OptimizedDiv.getNode() && OptimizedDiv.getNode() != Div.getNode()) { - assert((OptimizedDiv.getOpcode() != ISD::UDIVREM) && - (OptimizedDiv.getOpcode() != ISD::SDIVREM)); + // Reuse the SDIVLike/UDIVLike combines - to avoid mangling nodes, the + // speculative DIV must not cause a DIVREM conversion. We guard against this + // by skipping the simplification if isIntDivCheap(). When div is not cheap, + // combine will not return a DIVREM. Regardless, checking cheapness here + // makes sense since the simplification results in fatter code. + if (DAG.isKnownNeverZero(N1) && !TLI.isIntDivCheap(VT, Attr)) { + SDValue OptimizedDiv = + isSigned ? visitSDIVLike(N0, N1, N) : visitUDIVLike(N0, N1, N); + if (OptimizedDiv.getNode() && OptimizedDiv.getOpcode() != ISD::UDIVREM && + OptimizedDiv.getOpcode() != ISD::SDIVREM) { SDValue Mul = DAG.getNode(ISD::MUL, DL, VT, OptimizedDiv, N1); SDValue Sub = DAG.getNode(ISD::SUB, DL, VT, N0, Mul); + AddToWorklist(OptimizedDiv.getNode()); AddToWorklist(Mul.getNode()); return Sub; } @@ -3350,6 +3629,25 @@ SDValue DAGCombiner::visitIMINMAX(SDNode *N) { !DAG.isConstantIntBuildVectorOrConstantInt(N1)) return DAG.getNode(N->getOpcode(), SDLoc(N), VT, N1, N0); + // If the sign bits are zero, flip between UMIN/UMAX and SMIN/SMAX. + // Only do this if the current op isn't legal and the flipped one is. 
+ unsigned Opcode = N->getOpcode(); + const TargetLowering &TLI = DAG.getTargetLoweringInfo(); + if (!TLI.isOperationLegal(Opcode, VT) && + (N0.isUndef() || DAG.SignBitIsZero(N0)) && + (N1.isUndef() || DAG.SignBitIsZero(N1))) { + unsigned AltOpcode; + switch (Opcode) { + case ISD::SMIN: AltOpcode = ISD::UMIN; break; + case ISD::SMAX: AltOpcode = ISD::UMAX; break; + case ISD::UMIN: AltOpcode = ISD::SMIN; break; + case ISD::UMAX: AltOpcode = ISD::SMAX; break; + default: llvm_unreachable("Unknown MINMAX opcode"); + } + if (TLI.isOperationLegal(AltOpcode, VT)) + return DAG.getNode(AltOpcode, SDLoc(N), VT, N0, N1); + } + return SDValue(); } @@ -3469,9 +3767,9 @@ SDValue DAGCombiner::SimplifyBinOpWithSameOpcodeHands(SDNode *N) { ShOp = SDValue(); } - // (AND (shuf (A, C), shuf (B, C)) -> shuf (AND (A, B), C) - // (OR (shuf (A, C), shuf (B, C)) -> shuf (OR (A, B), C) - // (XOR (shuf (A, C), shuf (B, C)) -> shuf (XOR (A, B), V_0) + // (AND (shuf (A, C), shuf (B, C))) -> shuf (AND (A, B), C) + // (OR (shuf (A, C), shuf (B, C))) -> shuf (OR (A, B), C) + // (XOR (shuf (A, C), shuf (B, C))) -> shuf (XOR (A, B), V_0) if (N0.getOperand(1) == N1.getOperand(1) && ShOp.getNode()) { SDValue NewNode = DAG.getNode(N->getOpcode(), SDLoc(N), VT, N0->getOperand(0), N1->getOperand(0)); @@ -3490,9 +3788,9 @@ SDValue DAGCombiner::SimplifyBinOpWithSameOpcodeHands(SDNode *N) { ShOp = SDValue(); } - // (AND (shuf (C, A), shuf (C, B)) -> shuf (C, AND (A, B)) - // (OR (shuf (C, A), shuf (C, B)) -> shuf (C, OR (A, B)) - // (XOR (shuf (C, A), shuf (C, B)) -> shuf (V_0, XOR (A, B)) + // (AND (shuf (C, A), shuf (C, B))) -> shuf (C, AND (A, B)) + // (OR (shuf (C, A), shuf (C, B))) -> shuf (C, OR (A, B)) + // (XOR (shuf (C, A), shuf (C, B))) -> shuf (V_0, XOR (A, B)) if (N0->getOperand(0) == N1->getOperand(0) && ShOp.getNode()) { SDValue NewNode = DAG.getNode(N->getOpcode(), SDLoc(N), VT, N0->getOperand(1), N1->getOperand(1)); @@ -3525,7 +3823,7 @@ SDValue DAGCombiner::foldLogicOfSetCCs(bool IsAnd, SDValue N0, SDValue N1, // operations on the left and right operands, so those types must match. EVT VT = N0.getValueType(); EVT OpVT = LL.getValueType(); - if (LegalOperations || VT != MVT::i1) + if (LegalOperations || VT.getScalarType() != MVT::i1) if (VT != getSetCCResultType(OpVT)) return SDValue(); if (OpVT != RL.getValueType()) @@ -3762,53 +4060,78 @@ bool DAGCombiner::isAndLoadExtLoad(ConstantSDNode *AndC, LoadSDNode *LoadN, return true; } -bool DAGCombiner::isLegalNarrowLoad(LoadSDNode *LoadN, ISD::LoadExtType ExtType, - EVT &ExtVT, unsigned ShAmt) { - // Don't transform one with multiple uses, this would require adding a new - // load. - if (!SDValue(LoadN, 0).hasOneUse()) +bool DAGCombiner::isLegalNarrowLdSt(LSBaseSDNode *LDST, + ISD::LoadExtType ExtType, EVT &MemVT, + unsigned ShAmt) { + if (!LDST) return false; - - if (LegalOperations && - !TLI.isLoadExtLegal(ExtType, LoadN->getValueType(0), ExtVT)) + // Only allow byte offsets. + if (ShAmt % 8) return false; // Do not generate loads of non-round integer types since these can // be expensive (and would be wrong if the type is not byte sized). - if (!ExtVT.isRound()) + if (!MemVT.isRound()) return false; // Don't change the width of a volatile load. - if (LoadN->isVolatile()) + if (LDST->isVolatile()) return false; // Verify that we are actually reducing a load width here. - if (LoadN->getMemoryVT().getSizeInBits() < ExtVT.getSizeInBits()) - return false; - - // For the transform to be legal, the load must produce only two values - // (the value loaded and the chain). 
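The UMIN/SMIN flip in visitIMINMAX above is sound because signed and unsigned orderings agree once both sign bits are known zero; a standalone spot check:

  #include <algorithm>
  #include <cassert>
  #include <cstdint>

  int main() {
    // With both sign bits zero, signed and unsigned min/max coincide.
    for (int32_t A : {0, 1, 5, 1000, INT32_MAX})
      for (int32_t B : {0, 2, 7, 999}) {
        assert(std::min(A, B) == (int32_t)std::min((uint32_t)A, (uint32_t)B));
        assert(std::max(A, B) == (int32_t)std::max((uint32_t)A, (uint32_t)B));
      }
    return 0;
  }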
Don't transform a pre-increment - // load, for example, which produces an extra value. Otherwise the - // transformation is not equivalent, and the downstream logic to replace - // uses gets things wrong. - if (LoadN->getNumValues() > 2) + if (LDST->getMemoryVT().getSizeInBits() < MemVT.getSizeInBits()) return false; - // If the load that we're shrinking is an extload and we're not just - // discarding the extension we can't simply shrink the load. Bail. - // TODO: It would be possible to merge the extensions in some cases. - if (LoadN->getExtensionType() != ISD::NON_EXTLOAD && - LoadN->getMemoryVT().getSizeInBits() < ExtVT.getSizeInBits() + ShAmt) - return false; - - if (!TLI.shouldReduceLoadWidth(LoadN, ExtType, ExtVT)) + // Ensure that this isn't going to produce an unsupported unaligned access. + if (ShAmt && + !TLI.allowsMemoryAccess(*DAG.getContext(), DAG.getDataLayout(), MemVT, + LDST->getAddressSpace(), ShAmt / 8)) return false; // It's not possible to generate a constant of extended or untyped type. - EVT PtrType = LoadN->getOperand(1).getValueType(); + EVT PtrType = LDST->getBasePtr().getValueType(); if (PtrType == MVT::Untyped || PtrType.isExtended()) return false; + if (isa<LoadSDNode>(LDST)) { + LoadSDNode *Load = cast<LoadSDNode>(LDST); + // Don't transform one with multiple uses, this would require adding a new + // load. + if (!SDValue(Load, 0).hasOneUse()) + return false; + + if (LegalOperations && + !TLI.isLoadExtLegal(ExtType, Load->getValueType(0), MemVT)) + return false; + + // For the transform to be legal, the load must produce only two values + // (the value loaded and the chain). Don't transform a pre-increment + // load, for example, which produces an extra value. Otherwise the + // transformation is not equivalent, and the downstream logic to replace + // uses gets things wrong. + if (Load->getNumValues() > 2) + return false; + + // If the load that we're shrinking is an extload and we're not just + // discarding the extension we can't simply shrink the load. Bail. + // TODO: It would be possible to merge the extensions in some cases. + if (Load->getExtensionType() != ISD::NON_EXTLOAD && + Load->getMemoryVT().getSizeInBits() < MemVT.getSizeInBits() + ShAmt) + return false; + + if (!TLI.shouldReduceLoadWidth(Load, ExtType, MemVT)) + return false; + } else { + assert(isa<StoreSDNode>(LDST) && "It is not a Load nor a Store SDNode"); + StoreSDNode *Store = cast<StoreSDNode>(LDST); + // Can't write outside the original store + if (Store->getMemoryVT().getSizeInBits() < MemVT.getSizeInBits() + ShAmt) + return false; + + if (LegalOperations && + !TLI.isTruncStoreLegal(Store->getValue().getValueType(), MemVT)) + return false; + } return true; } @@ -3841,16 +4164,22 @@ bool DAGCombiner::SearchForAndLoads(SDNode *N, auto *Load = cast<LoadSDNode>(Op); EVT ExtVT; if (isAndLoadExtLoad(Mask, Load, Load->getValueType(0), ExtVT) && - isLegalNarrowLoad(Load, ISD::ZEXTLOAD, ExtVT)) { - // Only add this load if we can make it more narrow. - if (ExtVT.bitsLT(Load->getMemoryVT())) + isLegalNarrowLdSt(Load, ISD::ZEXTLOAD, ExtVT)) { + + // ZEXTLOAD is already small enough. + if (Load->getExtensionType() == ISD::ZEXTLOAD && + ExtVT.bitsGE(Load->getMemoryVT())) + continue; + + // Use LE to convert equal sized loads to zext. 
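The memory fact behind the load-narrowing legality checks in isLegalNarrowLdSt and SearchForAndLoads is that masking a wide load keeps only its low bytes, which equal a zero-extended narrow load of the same address on a little-endian target. A sketch, little-endian host assumed:

  #include <cassert>
  #include <cstdint>
  #include <cstring>

  int main() {
    // and (load i32 p), 0xFFFF == zextload i16 p (little-endian layout)
    unsigned char Mem[4] = {0x78, 0x56, 0x34, 0x12};
    uint32_t Wide;
    uint16_t Narrow;
    std::memcpy(&Wide, Mem, 4);   // 0x12345678 on little-endian
    std::memcpy(&Narrow, Mem, 2); // 0x5678
    assert((Wide & 0xFFFF) == (uint32_t)Narrow);
    return 0;
  }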
+      if (ExtVT.bitsLE(Load->getMemoryVT()))
        Loads.insert(Load);
+
      continue;
    }
    return false;
  }
  case ISD::ZERO_EXTEND:
-  case ISD::ANY_EXTEND:
  case ISD::AssertZext: {
    unsigned ActiveBits = Mask->getAPIntValue().countTrailingOnes();
    EVT ExtVT = EVT::getIntegerVT(*DAG.getContext(), ActiveBits);
@@ -3876,7 +4205,23 @@ bool DAGCombiner::SearchForAndLoads(SDNode *N,
    // Allow one node which will be masked along with any loads found.
    if (NodeToMask)
      return false;
+
+    // Also ensure that the node to be masked only produces one data result.
    NodeToMask = Op.getNode();
+    if (NodeToMask->getNumValues() > 1) {
+      bool HasValue = false;
+      for (unsigned i = 0, e = NodeToMask->getNumValues(); i < e; ++i) {
+        MVT VT = SDValue(NodeToMask, i).getSimpleValueType();
+        if (VT != MVT::Glue && VT != MVT::Other) {
+          if (HasValue) {
+            NodeToMask = nullptr;
+            return false;
+          }
+          HasValue = true;
+        }
+      }
+      assert(HasValue && "Node to be masked has no data result?");
+    }
  }
  return true;
}
@@ -3900,33 +4245,44 @@ bool DAGCombiner::BackwardsPropagateMask(SDNode *N, SelectionDAG &DAG) {
  if (Loads.size() == 0)
    return false;

+  LLVM_DEBUG(dbgs() << "Backwards propagate AND: "; N->dump());
  SDValue MaskOp = N->getOperand(1);

  // If it exists, fixup the single node we allow in the tree that needs
  // masking.
  if (FixupNode) {
+    LLVM_DEBUG(dbgs() << "First, need to fix up: "; FixupNode->dump());
    SDValue And = DAG.getNode(ISD::AND, SDLoc(FixupNode),
                              FixupNode->getValueType(0),
                              SDValue(FixupNode, 0), MaskOp);
    DAG.ReplaceAllUsesOfValueWith(SDValue(FixupNode, 0), And);
-      DAG.UpdateNodeOperands(And.getNode(), SDValue(FixupNode, 0),
-                             MaskOp);
+    if (And.getOpcode() == ISD::AND)
+      DAG.UpdateNodeOperands(And.getNode(), SDValue(FixupNode, 0), MaskOp);
  }

  // Narrow any constants that need it.
  for (auto *LogicN : NodesWithConsts) {
-    auto *C = cast<ConstantSDNode>(LogicN->getOperand(1));
-    SDValue And = DAG.getNode(ISD::AND, SDLoc(C), C->getValueType(0),
-                              SDValue(C, 0), MaskOp);
-    DAG.UpdateNodeOperands(LogicN, LogicN->getOperand(0), And);
+    SDValue Op0 = LogicN->getOperand(0);
+    SDValue Op1 = LogicN->getOperand(1);
+
+    if (isa<ConstantSDNode>(Op0))
+      std::swap(Op0, Op1);
+
+    SDValue And = DAG.getNode(ISD::AND, SDLoc(Op1), Op1.getValueType(),
+                              Op1, MaskOp);
+
+    DAG.UpdateNodeOperands(LogicN, Op0, And);
  }

  // Create narrow loads.
  for (auto *Load : Loads) {
+    LLVM_DEBUG(dbgs() << "Propagate AND back to: "; Load->dump());
    SDValue And = DAG.getNode(ISD::AND, SDLoc(Load), Load->getValueType(0),
                              SDValue(Load, 0), MaskOp);
    DAG.ReplaceAllUsesOfValueWith(SDValue(Load, 0), And);
-      DAG.UpdateNodeOperands(And.getNode(), SDValue(Load, 0), MaskOp);
+    if (And.getOpcode() == ISD::AND)
+      And = SDValue(
+          DAG.UpdateNodeOperands(And.getNode(), SDValue(Load, 0), MaskOp), 0);
    SDValue NewLoad = ReduceLoadWidth(And.getNode());
    assert(NewLoad &&
           "Shouldn't be masking the load if it can't be narrowed");
@@ -3938,6 +4294,60 @@ bool DAGCombiner::BackwardsPropagateMask(SDNode *N, SelectionDAG &DAG) {
  return false;
}

+// Unfold
+//    x &  (-1 'logical shift' y)
+// To
+//    (x 'opposite logical shift' y) 'logical shift' y
+// if it is better for performance.
+SDValue DAGCombiner::unfoldExtremeBitClearingToShifts(SDNode *N) {
+  assert(N->getOpcode() == ISD::AND);
+
+  SDValue N0 = N->getOperand(0);
+  SDValue N1 = N->getOperand(1);
+
+  // Do we actually prefer shifts over a mask?
+  if (!TLI.preferShiftsToClearExtremeBits(N0))
+    return SDValue();
+
+  // Try to match  (-1 '[outer] logical shift' y)
+  unsigned OuterShift;
+  unsigned InnerShift; // The opposite direction to the OuterShift.
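+  // e.g. when the mask is (-1 << y), OuterShift is SHL and InnerShift is SRL;
+  // the unfolded form (x >> y) << y clears the low y bits of x exactly as
+  // x & (-1 << y) does.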
+ SDValue Y; // Shift amount. + auto matchMask = [&OuterShift, &InnerShift, &Y](SDValue M) -> bool { + if (!M.hasOneUse()) + return false; + OuterShift = M->getOpcode(); + if (OuterShift == ISD::SHL) + InnerShift = ISD::SRL; + else if (OuterShift == ISD::SRL) + InnerShift = ISD::SHL; + else + return false; + if (!isAllOnesConstant(M->getOperand(0))) + return false; + Y = M->getOperand(1); + return true; + }; + + SDValue X; + if (matchMask(N1)) + X = N0; + else if (matchMask(N0)) + X = N1; + else + return SDValue(); + + SDLoc DL(N); + EVT VT = N->getValueType(0); + + // tmp = x 'opposite logical shift' y + SDValue T0 = DAG.getNode(InnerShift, DL, VT, X, Y); + // ret = tmp 'logical shift' y + SDValue T1 = DAG.getNode(OuterShift, DL, VT, T0, Y); + + return T1; +} + SDValue DAGCombiner::visitAND(SDNode *N) { SDValue N0 = N->getOperand(0); SDValue N1 = N->getOperand(1); @@ -4004,7 +4414,7 @@ SDValue DAGCombiner::visitAND(SDNode *N) { return RHS->getAPIntValue().isSubsetOf(LHS->getAPIntValue()); }; if (N0.getOpcode() == ISD::OR && - matchBinaryPredicate(N0.getOperand(1), N1, MatchSubset)) + ISD::matchBinaryPredicate(N0.getOperand(1), N1, MatchSubset)) return N1; // fold (and (any_ext V), c) -> (zero_ext V) if 'and' only clears top bits. if (N1C && N0.getOpcode() == ISD::ANY_EXTEND) { @@ -4235,6 +4645,9 @@ SDValue DAGCombiner::visitAND(SDNode *N) { return BSwap; } + if (SDValue Shifts = unfoldExtremeBitClearingToShifts(N)) + return Shifts; + return SDValue(); } @@ -4261,7 +4674,10 @@ SDValue DAGCombiner::MatchBSwapHWordLow(SDNode *N, SDValue N0, SDValue N1, if (!N0.getNode()->hasOneUse()) return SDValue(); ConstantSDNode *N01C = dyn_cast<ConstantSDNode>(N0.getOperand(1)); - if (!N01C || N01C->getZExtValue() != 0xFF00) + // Also handle 0xffff since the LHS is guaranteed to have zeros there. + // This is needed for X86. + if (!N01C || (N01C->getZExtValue() != 0xFF00 && + N01C->getZExtValue() != 0xFFFF)) return SDValue(); N0 = N0.getOperand(0); LookPassAnd0 = true; @@ -4308,7 +4724,10 @@ SDValue DAGCombiner::MatchBSwapHWordLow(SDNode *N, SDValue N0, SDValue N1, if (!N10.getNode()->hasOneUse()) return SDValue(); ConstantSDNode *N101C = dyn_cast<ConstantSDNode>(N10.getOperand(1)); - if (!N101C || N101C->getZExtValue() != 0xFF00) + // Also allow 0xFFFF since the bits will be shifted out. This is needed + // for X86. + if (!N101C || (N101C->getZExtValue() != 0xFF00 && + N101C->getZExtValue() != 0xFFFF)) return SDValue(); N10 = N10.getOperand(0); LookPassAnd1 = true; @@ -4379,6 +4798,14 @@ static bool isBSwapHWordElement(SDValue N, MutableArrayRef<SDNode *> Parts) { return false; case 0xFF: MaskByteOffset = 0; break; case 0xFF00: MaskByteOffset = 1; break; + case 0xFFFF: + // In case demanded bits didn't clear the bits that will be shifted out. + // This is needed for X86. 
+ if (Opc == ISD::SRL || (Opc == ISD::AND && Opc0 == ISD::SHL)) { + MaskByteOffset = 1; + break; + } + return false; case 0xFF0000: MaskByteOffset = 2; break; case 0xFF000000: MaskByteOffset = 3; break; } @@ -4693,7 +5120,7 @@ SDValue DAGCombiner::visitOR(SDNode *N) { return LHS->getAPIntValue().intersects(RHS->getAPIntValue()); }; if (N0.getOpcode() == ISD::AND && N0.getNode()->hasOneUse() && - matchBinaryPredicate(N0.getOperand(1), N1, MatchIntersect)) { + ISD::matchBinaryPredicate(N0.getOperand(1), N1, MatchIntersect)) { if (SDValue COR = DAG.FoldConstantArithmetic( ISD::OR, SDLoc(N1), VT, N1.getNode(), N0.getOperand(1).getNode())) { SDValue IOR = DAG.getNode(ISD::OR, SDLoc(N0), VT, N0.getOperand(0), N1); @@ -4749,7 +5176,8 @@ bool DAGCombiner::MatchRotateHalf(SDValue Op, SDValue &Shift, SDValue &Mask) { // reduces to a rotate in direction shift2 by Pos or (equivalently) a rotate // in direction shift1 by Neg. The range [0, EltSize) means that we only need // to consider shift amounts with defined behavior. -static bool matchRotateSub(SDValue Pos, SDValue Neg, unsigned EltSize) { +static bool matchRotateSub(SDValue Pos, SDValue Neg, unsigned EltSize, + SelectionDAG &DAG) { // If EltSize is a power of 2 then: // // (a) (Pos == 0 ? 0 : EltSize - Pos) == (EltSize - Pos) & (EltSize - 1) @@ -4784,9 +5212,13 @@ static bool matchRotateSub(SDValue Pos, SDValue Neg, unsigned EltSize) { unsigned MaskLoBits = 0; if (Neg.getOpcode() == ISD::AND && isPowerOf2_64(EltSize)) { if (ConstantSDNode *NegC = isConstOrConstSplat(Neg.getOperand(1))) { - if (NegC->getAPIntValue() == EltSize - 1) { + KnownBits Known; + DAG.computeKnownBits(Neg.getOperand(0), Known); + unsigned Bits = Log2_64(EltSize); + if (NegC->getAPIntValue().getActiveBits() <= Bits && + ((NegC->getAPIntValue() | Known.Zero).countTrailingOnes() >= Bits)) { Neg = Neg.getOperand(0); - MaskLoBits = Log2_64(EltSize); + MaskLoBits = Bits; } } } @@ -4801,10 +5233,16 @@ static bool matchRotateSub(SDValue Pos, SDValue Neg, unsigned EltSize) { // On the RHS of [A], if Pos is Pos' & (EltSize - 1), just replace Pos with // Pos'. The truncation is redundant for the purpose of the equality. - if (MaskLoBits && Pos.getOpcode() == ISD::AND) - if (ConstantSDNode *PosC = isConstOrConstSplat(Pos.getOperand(1))) - if (PosC->getAPIntValue() == EltSize - 1) + if (MaskLoBits && Pos.getOpcode() == ISD::AND) { + if (ConstantSDNode *PosC = isConstOrConstSplat(Pos.getOperand(1))) { + KnownBits Known; + DAG.computeKnownBits(Pos.getOperand(0), Known); + if (PosC->getAPIntValue().getActiveBits() <= MaskLoBits && + ((PosC->getAPIntValue() | Known.Zero).countTrailingOnes() >= + MaskLoBits)) Pos = Pos.getOperand(0); + } + } // The condition we need is now: // @@ -4860,7 +5298,7 @@ SDNode *DAGCombiner::MatchRotatePosNeg(SDValue Shifted, SDValue Pos, // (srl x, (*ext y))) -> // (rotr x, y) or (rotl x, (sub 32, y)) EVT VT = Shifted.getValueType(); - if (matchRotateSub(InnerPos, InnerNeg, VT.getScalarSizeInBits())) { + if (matchRotateSub(InnerPos, InnerNeg, VT.getScalarSizeInBits(), DAG)) { bool HasPos = TLI.isOperationLegalOrCustom(PosOpcode, VT); return DAG.getNode(HasPos ? PosOpcode : NegOpcode, DL, VT, Shifted, HasPos ? Pos : Neg).getNode(); @@ -4878,8 +5316,8 @@ SDNode *DAGCombiner::MatchRotate(SDValue LHS, SDValue RHS, const SDLoc &DL) { if (!TLI.isTypeLegal(VT)) return nullptr; // The target must have at least one rotate flavor. 
- bool HasROTL = TLI.isOperationLegalOrCustom(ISD::ROTL, VT); - bool HasROTR = TLI.isOperationLegalOrCustom(ISD::ROTR, VT); + bool HasROTL = hasOperation(ISD::ROTL, VT); + bool HasROTR = hasOperation(ISD::ROTR, VT); if (!HasROTL && !HasROTR) return nullptr; // Check for truncated rotate. @@ -4928,7 +5366,7 @@ SDNode *DAGCombiner::MatchRotate(SDValue LHS, SDValue RHS, const SDLoc &DL) { ConstantSDNode *RHS) { return (LHS->getAPIntValue() + RHS->getAPIntValue()) == EltSizeInBits; }; - if (matchBinaryPredicate(LHSShiftAmt, RHSShiftAmt, MatchRotateSum)) { + if (ISD::matchBinaryPredicate(LHSShiftAmt, RHSShiftAmt, MatchRotateSum)) { SDValue Rot = DAG.getNode(HasROTL ? ISD::ROTL : ISD::ROTR, DL, VT, LHSShiftArg, HasROTL ? LHSShiftAmt : RHSShiftAmt); @@ -5185,7 +5623,7 @@ SDValue DAGCombiner::MatchLoadCombine(SDNode *N) { Optional<BaseIndexOffset> Base; SDValue Chain; - SmallSet<LoadSDNode *, 8> Loads; + SmallPtrSet<LoadSDNode *, 8> Loads; Optional<ByteProvider> FirstByteProvider; int64_t FirstOffset = INT64_MAX; @@ -5210,7 +5648,7 @@ SDValue DAGCombiner::MatchLoadCombine(SDNode *N) { return SDValue(); // Loads must share the same base address - BaseIndexOffset Ptr = BaseIndexOffset::match(L->getBasePtr(), DAG); + BaseIndexOffset Ptr = BaseIndexOffset::match(L, DAG); int64_t ByteOffsetFromBase = 0; if (!Base) Base = Ptr; @@ -5284,6 +5722,88 @@ SDValue DAGCombiner::MatchLoadCombine(SDNode *N) { return NeedsBswap ? DAG.getNode(ISD::BSWAP, SDLoc(N), VT, NewLoad) : NewLoad; } +// If the target has andn, bsl, or a similar bit-select instruction, +// we want to unfold masked merge, with canonical pattern of: +// | A | |B| +// ((x ^ y) & m) ^ y +// | D | +// Into: +// (x & m) | (y & ~m) +// If y is a constant, and the 'andn' does not work with immediates, +// we unfold into a different pattern: +// ~(~x & m) & (m | y) +// NOTE: we don't unfold the pattern if 'xor' is actually a 'not', because at +// the very least that breaks andnpd / andnps patterns, and because those +// patterns are simplified in IR and shouldn't be created in the DAG +SDValue DAGCombiner::unfoldMaskedMerge(SDNode *N) { + assert(N->getOpcode() == ISD::XOR); + + // Don't touch 'not' (i.e. where y = -1). + if (isAllOnesConstantOrAllOnesSplatConstant(N->getOperand(1))) + return SDValue(); + + EVT VT = N->getValueType(0); + + // There are 3 commutable operators in the pattern, + // so we have to deal with 8 possible variants of the basic pattern. + SDValue X, Y, M; + auto matchAndXor = [&X, &Y, &M](SDValue And, unsigned XorIdx, SDValue Other) { + if (And.getOpcode() != ISD::AND || !And.hasOneUse()) + return false; + SDValue Xor = And.getOperand(XorIdx); + if (Xor.getOpcode() != ISD::XOR || !Xor.hasOneUse()) + return false; + SDValue Xor0 = Xor.getOperand(0); + SDValue Xor1 = Xor.getOperand(1); + // Don't touch 'not' (i.e. where y = -1). + if (isAllOnesConstantOrAllOnesSplatConstant(Xor1)) + return false; + if (Other == Xor0) + std::swap(Xor0, Xor1); + if (Other != Xor1) + return false; + X = Xor0; + Y = Xor1; + M = And.getOperand(XorIdx ? 0 : 1); + return true; + }; + + SDValue N0 = N->getOperand(0); + SDValue N1 = N->getOperand(1); + if (!matchAndXor(N0, 0, N1) && !matchAndXor(N0, 1, N1) && + !matchAndXor(N1, 0, N0) && !matchAndXor(N1, 1, N0)) + return SDValue(); + + // Don't do anything if the mask is constant. This should not be reachable. + // InstCombine should have already unfolded this pattern, and DAGCombiner + // probably shouldn't produce it, too. 
+ if (isa<ConstantSDNode>(M.getNode())) + return SDValue(); + + // We can transform if the target has AndNot + if (!TLI.hasAndNot(M)) + return SDValue(); + + SDLoc DL(N); + + // If Y is a constant, check that 'andn' works with immediates. + if (!TLI.hasAndNot(Y)) { + assert(TLI.hasAndNot(X) && "Only mask is a variable? Unreachable."); + // If not, we need to do a bit more work to make sure andn is still used. + SDValue NotX = DAG.getNOT(DL, X, VT); + SDValue LHS = DAG.getNode(ISD::AND, DL, VT, NotX, M); + SDValue NotLHS = DAG.getNOT(DL, LHS, VT); + SDValue RHS = DAG.getNode(ISD::OR, DL, VT, M, Y); + return DAG.getNode(ISD::AND, DL, VT, NotLHS, RHS); + } + + SDValue LHS = DAG.getNode(ISD::AND, DL, VT, X, M); + SDValue NotM = DAG.getNOT(DL, M, VT); + SDValue RHS = DAG.getNode(ISD::AND, DL, VT, Y, NotM); + + return DAG.getNode(ISD::OR, DL, VT, LHS, RHS); +} + SDValue DAGCombiner::visitXOR(SDNode *N) { SDValue N0 = N->getOperand(0); SDValue N1 = N->getOperand(1); @@ -5363,7 +5883,7 @@ SDValue DAGCombiner::visitXOR(SDNode *N) { } // fold (not (or x, y)) -> (and (not x), (not y)) iff x or y are setcc - if (isOneConstant(N1) && VT == MVT::i1 && + if (isOneConstant(N1) && VT == MVT::i1 && N0.hasOneUse() && (N0.getOpcode() == ISD::OR || N0.getOpcode() == ISD::AND)) { SDValue LHS = N0.getOperand(0), RHS = N0.getOperand(1); if (isOneUseSetCC(RHS) || isOneUseSetCC(LHS)) { @@ -5375,7 +5895,7 @@ SDValue DAGCombiner::visitXOR(SDNode *N) { } } // fold (not (or x, y)) -> (and (not x), (not y)) iff x or y are constants - if (isAllOnesConstant(N1) && + if (isAllOnesConstant(N1) && N0.hasOneUse() && (N0.getOpcode() == ISD::OR || N0.getOpcode() == ISD::AND)) { SDValue LHS = N0.getOperand(0), RHS = N0.getOperand(1); if (isa<ConstantSDNode>(RHS) || isa<ConstantSDNode>(LHS)) { @@ -5396,13 +5916,19 @@ SDValue DAGCombiner::visitXOR(SDNode *N) { } // fold Y = sra (X, size(X)-1); xor (add (X, Y), Y) -> (abs X) - unsigned OpSizeInBits = VT.getScalarSizeInBits(); - if (N0.getOpcode() == ISD::ADD && N0.getOperand(1) == N1 && - N1.getOpcode() == ISD::SRA && N1.getOperand(0) == N0.getOperand(0) && - TLI.isOperationLegalOrCustom(ISD::ABS, VT)) { - if (ConstantSDNode *C = isConstOrConstSplat(N1.getOperand(1))) - if (C->getAPIntValue() == (OpSizeInBits - 1)) - return DAG.getNode(ISD::ABS, SDLoc(N), VT, N0.getOperand(0)); + if (TLI.isOperationLegalOrCustom(ISD::ABS, VT)) { + SDValue A = N0.getOpcode() == ISD::ADD ? N0 : N1; + SDValue S = N0.getOpcode() == ISD::SRA ? N0 : N1; + if (A.getOpcode() == ISD::ADD && S.getOpcode() == ISD::SRA) { + SDValue A0 = A.getOperand(0), A1 = A.getOperand(1); + SDValue S0 = S.getOperand(0); + if ((A0 == S && A1 == S0) || (A1 == S && A0 == S0)) { + unsigned OpSizeInBits = VT.getScalarSizeInBits(); + if (ConstantSDNode *C = isConstOrConstSplat(S.getOperand(1))) + if (C->getAPIntValue() == (OpSizeInBits - 1)) + return DAG.getNode(ISD::ABS, SDLoc(N), VT, S0); + } + } } // fold (xor x, x) -> 0 @@ -5439,6 +5965,10 @@ SDValue DAGCombiner::visitXOR(SDNode *N) { if (SDValue Tmp = SimplifyBinOpWithSameOpcodeHands(N)) return Tmp; + // Unfold ((x ^ y) & m) ^ y into (x & m) | (y & ~m) if profitable + if (SDValue MM = unfoldMaskedMerge(N)) + return MM; + // Simplify the expression using non-local knowledge. 
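  // (On success, SimplifyDemandedBits has already updated N in place, so
  // returning SDValue(N, 0) reports the node as changed.)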
if (SimplifyDemandedBits(SDValue(N, 0))) return SDValue(N, 0); @@ -5641,7 +6171,7 @@ SDValue DAGCombiner::visitSHL(SDNode *N) { auto MatchShiftTooBig = [OpSizeInBits](ConstantSDNode *Val) { return Val->getAPIntValue().uge(OpSizeInBits); }; - if (matchUnaryPredicate(N1, MatchShiftTooBig)) + if (ISD::matchUnaryPredicate(N1, MatchShiftTooBig)) return DAG.getUNDEF(VT); // fold (shl x, 0) -> x if (N1C && N1C->isNullValue()) @@ -5676,7 +6206,7 @@ SDValue DAGCombiner::visitSHL(SDNode *N) { zeroExtendToMatch(c1, c2, 1 /* Overflow Bit */); return (c1 + c2).uge(OpSizeInBits); }; - if (matchBinaryPredicate(N1, N0.getOperand(1), MatchOutOfRange)) + if (ISD::matchBinaryPredicate(N1, N0.getOperand(1), MatchOutOfRange)) return DAG.getConstant(0, SDLoc(N), VT); auto MatchInRange = [OpSizeInBits](ConstantSDNode *LHS, @@ -5686,7 +6216,7 @@ SDValue DAGCombiner::visitSHL(SDNode *N) { zeroExtendToMatch(c1, c2, 1 /* Overflow Bit */); return (c1 + c2).ult(OpSizeInBits); }; - if (matchBinaryPredicate(N1, N0.getOperand(1), MatchInRange)) { + if (ISD::matchBinaryPredicate(N1, N0.getOperand(1), MatchInRange)) { SDLoc DL(N); EVT ShiftVT = N1.getValueType(); SDValue Sum = DAG.getNode(ISD::ADD, DL, ShiftVT, N1, N0.getOperand(1)); @@ -5862,7 +6392,7 @@ SDValue DAGCombiner::visitSRA(SDNode *N) { auto MatchShiftTooBig = [OpSizeInBits](ConstantSDNode *Val) { return Val->getAPIntValue().uge(OpSizeInBits); }; - if (matchUnaryPredicate(N1, MatchShiftTooBig)) + if (ISD::matchUnaryPredicate(N1, MatchShiftTooBig)) return DAG.getUNDEF(VT); // fold (sra x, 0) -> x if (N1C && N1C->isNullValue()) @@ -5897,7 +6427,7 @@ SDValue DAGCombiner::visitSRA(SDNode *N) { zeroExtendToMatch(c1, c2, 1 /* Overflow Bit */); return (c1 + c2).uge(OpSizeInBits); }; - if (matchBinaryPredicate(N1, N0.getOperand(1), MatchOutOfRange)) + if (ISD::matchBinaryPredicate(N1, N0.getOperand(1), MatchOutOfRange)) return DAG.getNode(ISD::SRA, DL, VT, N0.getOperand(0), DAG.getConstant(OpSizeInBits - 1, DL, ShiftVT)); @@ -5908,7 +6438,7 @@ SDValue DAGCombiner::visitSRA(SDNode *N) { zeroExtendToMatch(c1, c2, 1 /* Overflow Bit */); return (c1 + c2).ult(OpSizeInBits); }; - if (matchBinaryPredicate(N1, N0.getOperand(1), MatchInRange)) { + if (ISD::matchBinaryPredicate(N1, N0.getOperand(1), MatchInRange)) { SDValue Sum = DAG.getNode(ISD::ADD, DL, ShiftVT, N1, N0.getOperand(1)); return DAG.getNode(ISD::SRA, DL, VT, N0.getOperand(0), Sum); } @@ -6026,7 +6556,7 @@ SDValue DAGCombiner::visitSRL(SDNode *N) { auto MatchShiftTooBig = [OpSizeInBits](ConstantSDNode *Val) { return Val->getAPIntValue().uge(OpSizeInBits); }; - if (matchUnaryPredicate(N1, MatchShiftTooBig)) + if (ISD::matchUnaryPredicate(N1, MatchShiftTooBig)) return DAG.getUNDEF(VT); // fold (srl x, 0) -> x if (N1C && N1C->isNullValue()) @@ -6049,7 +6579,7 @@ SDValue DAGCombiner::visitSRL(SDNode *N) { zeroExtendToMatch(c1, c2, 1 /* Overflow Bit */); return (c1 + c2).uge(OpSizeInBits); }; - if (matchBinaryPredicate(N1, N0.getOperand(1), MatchOutOfRange)) + if (ISD::matchBinaryPredicate(N1, N0.getOperand(1), MatchOutOfRange)) return DAG.getConstant(0, SDLoc(N), VT); auto MatchInRange = [OpSizeInBits](ConstantSDNode *LHS, @@ -6059,7 +6589,7 @@ SDValue DAGCombiner::visitSRL(SDNode *N) { zeroExtendToMatch(c1, c2, 1 /* Overflow Bit */); return (c1 + c2).ult(OpSizeInBits); }; - if (matchBinaryPredicate(N1, N0.getOperand(1), MatchInRange)) { + if (ISD::matchBinaryPredicate(N1, N0.getOperand(1), MatchInRange)) { SDLoc DL(N); EVT ShiftVT = N1.getValueType(); SDValue Sum = DAG.getNode(ISD::ADD, DL, ShiftVT, N1, 
N0.getOperand(1)); @@ -6270,6 +6800,13 @@ SDValue DAGCombiner::visitCTLZ(SDNode *N) { // fold (ctlz c1) -> c2 if (DAG.isConstantIntBuildVectorOrConstantInt(N0)) return DAG.getNode(ISD::CTLZ, SDLoc(N), VT, N0); + + // If the value is known never to be zero, switch to the undef version. + if (!LegalOperations || TLI.isOperationLegal(ISD::CTLZ_ZERO_UNDEF, VT)) { + if (DAG.isKnownNeverZero(N0)) + return DAG.getNode(ISD::CTLZ_ZERO_UNDEF, SDLoc(N), VT, N0); + } + return SDValue(); } @@ -6290,6 +6827,13 @@ SDValue DAGCombiner::visitCTTZ(SDNode *N) { // fold (cttz c1) -> c2 if (DAG.isConstantIntBuildVectorOrConstantInt(N0)) return DAG.getNode(ISD::CTTZ, SDLoc(N), VT, N0); + + // If the value is known never to be zero, switch to the undef version. + if (!LegalOperations || TLI.isOperationLegal(ISD::CTTZ_ZERO_UNDEF, VT)) { + if (DAG.isKnownNeverZero(N0)) + return DAG.getNode(ISD::CTTZ_ZERO_UNDEF, SDLoc(N), VT, N0); + } + return SDValue(); } @@ -6313,7 +6857,7 @@ SDValue DAGCombiner::visitCTPOP(SDNode *N) { return SDValue(); } -/// \brief Generate Min/Max node +/// Generate Min/Max node static SDValue combineMinNumMaxNum(const SDLoc &DL, EVT VT, SDValue LHS, SDValue RHS, SDValue True, SDValue False, ISD::CondCode CC, const TargetLowering &TLI, @@ -6428,9 +6972,9 @@ SDValue DAGCombiner::foldSelectOfConstants(SDNode *N) { // in another basic block or it could require searching a complicated // expression. if (CondVT.isInteger() && - TLI.getBooleanContents(false, true) == + TLI.getBooleanContents(/*isVec*/false, /*isFloat*/true) == TargetLowering::ZeroOrOneBooleanContent && - TLI.getBooleanContents(false, false) == + TLI.getBooleanContents(/*isVec*/false, /*isFloat*/false) == TargetLowering::ZeroOrOneBooleanContent && C1->isNullValue() && C2->isOne()) { SDValue NotCond = @@ -6559,15 +7103,10 @@ SDValue DAGCombiner::visitSELECT(SDNode *N) { } } - // select (xor Cond, 1), X, Y -> select Cond, Y, X if (VT0 == MVT::i1) { - if (N0->getOpcode() == ISD::XOR) { - if (auto *C = dyn_cast<ConstantSDNode>(N0->getOperand(1))) { - SDValue Cond0 = N0->getOperand(0); - if (C->isOne()) - return DAG.getNode(ISD::SELECT, DL, N1.getValueType(), Cond0, N2, N1); - } - } + // select (not Cond), N1, N2 -> select Cond, N2, N1 + if (isBitwiseNot(N0)) + return DAG.getNode(ISD::SELECT, DL, VT, N0->getOperand(0), N2, N1); } // fold selects based on a setcc into other things, such as min/max/abs @@ -6711,6 +7250,7 @@ SDValue DAGCombiner::visitMSCATTER(SDNode *N) { SDValue DataLo, DataHi; std::tie(DataLo, DataHi) = DAG.SplitVector(Data, DL); + SDValue Scale = MSC->getScale(); SDValue BasePtr = MSC->getBasePtr(); SDValue IndexLo, IndexHi; std::tie(IndexLo, IndexHi) = DAG.SplitVector(MSC->getIndex(), DL); @@ -6720,11 +7260,11 @@ SDValue DAGCombiner::visitMSCATTER(SDNode *N) { MachineMemOperand::MOStore, LoMemVT.getStoreSize(), Alignment, MSC->getAAInfo(), MSC->getRanges()); - SDValue OpsLo[] = { Chain, DataLo, MaskLo, BasePtr, IndexLo }; + SDValue OpsLo[] = { Chain, DataLo, MaskLo, BasePtr, IndexLo, Scale }; Lo = DAG.getMaskedScatter(DAG.getVTList(MVT::Other), DataLo.getValueType(), DL, OpsLo, MMO); - SDValue OpsHi[] = {Chain, DataHi, MaskHi, BasePtr, IndexHi}; + SDValue OpsHi[] = { Chain, DataHi, MaskHi, BasePtr, IndexHi, Scale }; Hi = DAG.getMaskedScatter(DAG.getVTList(MVT::Other), DataHi.getValueType(), DL, OpsHi, MMO); @@ -6785,12 +7325,12 @@ SDValue DAGCombiner::visitMSTORE(SDNode *N) { Ptr = TLI.IncrementMemoryAddress(Ptr, MaskLo, DL, LoMemVT, DAG, MST->isCompressingStore()); + unsigned HiOffset = LoMemVT.getStoreSize(); - MMO = 
DAG.getMachineFunction(). - getMachineMemOperand(MST->getPointerInfo(), - MachineMemOperand::MOStore, HiMemVT.getStoreSize(), - SecondHalfAlignment, MST->getAAInfo(), - MST->getRanges()); + MMO = DAG.getMachineFunction().getMachineMemOperand( + MST->getPointerInfo().getWithOffset(HiOffset), + MachineMemOperand::MOStore, HiMemVT.getStoreSize(), SecondHalfAlignment, + MST->getAAInfo(), MST->getRanges()); Hi = DAG.getMaskedStore(Chain, DL, DataHi, Ptr, MaskHi, HiMemVT, MMO, MST->isTruncatingStore(), @@ -6844,6 +7384,7 @@ SDValue DAGCombiner::visitMGATHER(SDNode *N) { EVT LoMemVT, HiMemVT; std::tie(LoMemVT, HiMemVT) = DAG.GetSplitDestVTs(MemoryVT); + SDValue Scale = MGT->getScale(); SDValue BasePtr = MGT->getBasePtr(); SDValue Index = MGT->getIndex(); SDValue IndexLo, IndexHi; @@ -6854,13 +7395,13 @@ SDValue DAGCombiner::visitMGATHER(SDNode *N) { MachineMemOperand::MOLoad, LoMemVT.getStoreSize(), Alignment, MGT->getAAInfo(), MGT->getRanges()); - SDValue OpsLo[] = { Chain, Src0Lo, MaskLo, BasePtr, IndexLo }; + SDValue OpsLo[] = { Chain, Src0Lo, MaskLo, BasePtr, IndexLo, Scale }; Lo = DAG.getMaskedGather(DAG.getVTList(LoVT, MVT::Other), LoVT, DL, OpsLo, - MMO); + MMO); - SDValue OpsHi[] = {Chain, Src0Hi, MaskHi, BasePtr, IndexHi}; + SDValue OpsHi[] = { Chain, Src0Hi, MaskHi, BasePtr, IndexHi, Scale }; Hi = DAG.getMaskedGather(DAG.getVTList(HiVT, MVT::Other), HiVT, DL, OpsHi, - MMO); + MMO); AddToWorklist(Lo.getNode()); AddToWorklist(Hi.getNode()); @@ -6934,11 +7475,12 @@ SDValue DAGCombiner::visitMLOAD(SDNode *N) { Ptr = TLI.IncrementMemoryAddress(Ptr, MaskLo, DL, LoMemVT, DAG, MLD->isExpandingLoad()); + unsigned HiOffset = LoMemVT.getStoreSize(); - MMO = DAG.getMachineFunction(). - getMachineMemOperand(MLD->getPointerInfo(), - MachineMemOperand::MOLoad, HiMemVT.getStoreSize(), - SecondHalfAlignment, MLD->getAAInfo(), MLD->getRanges()); + MMO = DAG.getMachineFunction().getMachineMemOperand( + MLD->getPointerInfo().getWithOffset(HiOffset), + MachineMemOperand::MOLoad, HiMemVT.getStoreSize(), SecondHalfAlignment, + MLD->getAAInfo(), MLD->getRanges()); Hi = DAG.getMaskedLoad(HiVT, DL, Chain, Ptr, MaskHi, Src0Hi, HiMemVT, MMO, ISD::NON_EXTLOAD, MLD->isExpandingLoad()); @@ -7056,6 +7598,36 @@ SDValue DAGCombiner::visitVSELECT(SDNode *N) { AddToWorklist(Add.getNode()); return DAG.getNode(ISD::XOR, DL, VT, Add, Shift); } + + // If this select has a condition (setcc) with narrower operands than the + // select, try to widen the compare to match the select width. + // TODO: This should be extended to handle any constant. + // TODO: This could be extended to handle non-loading patterns, but that + // requires thorough testing to avoid regressions. + if (isNullConstantOrNullSplatConstant(RHS)) { + EVT NarrowVT = LHS.getValueType(); + EVT WideVT = N1.getValueType().changeVectorElementTypeToInteger(); + EVT SetCCVT = getSetCCResultType(LHS.getValueType()); + unsigned SetCCWidth = SetCCVT.getScalarSizeInBits(); + unsigned WideWidth = WideVT.getScalarSizeInBits(); + bool IsSigned = isSignedIntSetCC(CC); + auto LoadExtOpcode = IsSigned ? ISD::SEXTLOAD : ISD::ZEXTLOAD; + if (LHS.getOpcode() == ISD::LOAD && LHS.hasOneUse() && + SetCCWidth != 1 && SetCCWidth < WideWidth && + TLI.isLoadExtLegalOrCustom(LoadExtOpcode, WideVT, NarrowVT) && + TLI.isOperationLegalOrCustom(ISD::SETCC, WideVT)) { + // Both compare operands can be widened for free. 
The LHS can use an
+      // extended load, and the RHS is a constant:
+      //   vselect (ext (setcc load(X), C)), N1, N2 -->
+      //   vselect (setcc extload(X), C'), N1, N2
+      auto ExtOpcode = IsSigned ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND;
+      SDValue WideLHS = DAG.getNode(ExtOpcode, DL, WideVT, LHS);
+      SDValue WideRHS = DAG.getNode(ExtOpcode, DL, WideVT, RHS);
+      EVT WideSetCCVT = getSetCCResultType(WideVT);
+      SDValue WideSetCC = DAG.getSetCC(DL, WideSetCCVT, WideLHS, WideRHS, CC);
+      return DAG.getSelect(DL, N1.getValueType(), WideSetCC, N1, N2);
+    }
+  }

  if (SimplifySelectOps(N, N1, N2))
@@ -7127,22 +7699,33 @@ SDValue DAGCombiner::visitSELECT_CC(SDNode *N) {
}

SDValue DAGCombiner::visitSETCC(SDNode *N) {
-  return SimplifySetCC(N->getValueType(0), N->getOperand(0), N->getOperand(1),
-                       cast<CondCodeSDNode>(N->getOperand(2))->get(),
-                       SDLoc(N));
-}
+  // setcc is very commonly used as an argument to brcond. This pattern
+  // also lends itself to numerous combines and, as a result, it is desirable
+  // that we keep the argument to a brcond as a setcc as much as possible.
+  bool PreferSetCC =
+      N->hasOneUse() && N->use_begin()->getOpcode() == ISD::BRCOND;

-SDValue DAGCombiner::visitSETCCE(SDNode *N) {
-  SDValue LHS = N->getOperand(0);
-  SDValue RHS = N->getOperand(1);
-  SDValue Carry = N->getOperand(2);
-  SDValue Cond = N->getOperand(3);
+  SDValue Combined = SimplifySetCC(
+      N->getValueType(0), N->getOperand(0), N->getOperand(1),
+      cast<CondCodeSDNode>(N->getOperand(2))->get(), SDLoc(N), !PreferSetCC);

-  // If Carry is false, fold to a regular SETCC.
-  if (Carry.getOpcode() == ISD::CARRY_FALSE)
-    return DAG.getNode(ISD::SETCC, SDLoc(N), N->getVTList(), LHS, RHS, Cond);
+  if (!Combined)
+    return SDValue();

-  return SDValue();
+  // If we prefer to have a setcc, and we don't, we'll try our best to
+  // recreate one using rebuildSetCC.
+  if (PreferSetCC && Combined.getOpcode() != ISD::SETCC) {
+    SDValue NewSetCC = rebuildSetCC(Combined);
+
+    // We don't have anything interesting to combine to.
+    if (NewSetCC.getNode() == N)
+      return SDValue();
+
+    if (NewSetCC)
+      return NewSetCC;
+  }
+
+  return Combined;
}

SDValue DAGCombiner::visitSETCCCARRY(SDNode *N) {
@@ -7222,12 +7805,12 @@ static SDNode *tryToFoldExtendOfConstant(SDNode *N, const TargetLowering &TLI,
// "fold ({s|z|a}ext (load x)) -> ({s|z|a}ext (truncate ({s|z|a}extload x)))"
// transformation. Returns true if extensions are possible and the above
// mentioned transformation is profitable.
-static bool ExtendUsesToFormExtLoad(SDNode *N, SDValue N0,
+static bool ExtendUsesToFormExtLoad(EVT VT, SDNode *N, SDValue N0,
                                    unsigned ExtOpc,
                                    SmallVectorImpl<SDNode *> &ExtendNodes,
                                    const TargetLowering &TLI) {
  bool HasCopyToRegUses = false;
-  bool isTruncFree = TLI.isTruncateFree(N->getValueType(0), N0.getValueType());
+  bool isTruncFree = TLI.isTruncateFree(VT, N0.getValueType());
  for (SDNode::use_iterator UI = N0.getNode()->use_begin(),
                            UE = N0.getNode()->use_end();
       UI != UE; ++UI) {
@@ -7283,16 +7866,16 @@ static bool ExtendUsesToFormExtLoad(SDNode *N, SDValue N0,
}

void DAGCombiner::ExtendSetCCUses(const SmallVectorImpl<SDNode *> &SetCCs,
-                                  SDValue Trunc, SDValue ExtLoad,
-                                  const SDLoc &DL, ISD::NodeType ExtType) {
+                                  SDValue OrigLoad, SDValue ExtLoad,
+                                  ISD::NodeType ExtType) {
  // Extend SetCC uses if necessary.
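  // Each setcc that used the original narrow load must now be rewritten: the
  // loop below swaps in ExtLoad for OrigLoad and extends the other operand to
  // the same wide type so the comparison stays well-formed.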
- for (unsigned i = 0, e = SetCCs.size(); i != e; ++i) { - SDNode *SetCC = SetCCs[i]; + SDLoc DL(ExtLoad); + for (SDNode *SetCC : SetCCs) { SmallVector<SDValue, 4> Ops; for (unsigned j = 0; j != 2; ++j) { SDValue SOp = SetCC->getOperand(j); - if (SOp == Trunc) + if (SOp == OrigLoad) Ops.push_back(ExtLoad); else Ops.push_back(DAG.getNode(ExtType, DL, ExtLoad->getValueType(0), SOp)); @@ -7341,7 +7924,7 @@ SDValue DAGCombiner::CombineExtLoad(SDNode *N) { return SDValue(); SmallVector<SDNode *, 4> SetCCs; - if (!ExtendUsesToFormExtLoad(N, N0, N->getOpcode(), SetCCs, TLI)) + if (!ExtendUsesToFormExtLoad(DstVT, N, N0, N->getOpcode(), SetCCs, TLI)) return SDValue(); ISD::LoadExtType ExtType = @@ -7372,7 +7955,7 @@ SDValue DAGCombiner::CombineExtLoad(SDNode *N) { const unsigned Align = MinAlign(LN0->getAlignment(), Offset); SDValue SplitLoad = DAG.getExtLoad( - ExtType, DL, SplitDstVT, LN0->getChain(), BasePtr, + ExtType, SDLoc(LN0), SplitDstVT, LN0->getChain(), BasePtr, LN0->getPointerInfo().getWithOffset(Offset), SplitSrcVT, Align, LN0->getMemOperand()->getFlags(), LN0->getAAInfo()); @@ -7395,12 +7978,82 @@ SDValue DAGCombiner::CombineExtLoad(SDNode *N) { // with a truncate of the concatenated sextloaded vectors. SDValue Trunc = DAG.getNode(ISD::TRUNCATE, SDLoc(N0), N0.getValueType(), NewValue); + ExtendSetCCUses(SetCCs, N0, NewValue, (ISD::NodeType)N->getOpcode()); CombineTo(N0.getNode(), Trunc, NewChain); - ExtendSetCCUses(SetCCs, Trunc, NewValue, DL, - (ISD::NodeType)N->getOpcode()); return SDValue(N, 0); // Return N so it doesn't get rechecked! } +// fold (zext (and/or/xor (shl/shr (load x), cst), cst)) -> +// (and/or/xor (shl/shr (zextload x), (zext cst)), (zext cst)) +SDValue DAGCombiner::CombineZExtLogicopShiftLoad(SDNode *N) { + assert(N->getOpcode() == ISD::ZERO_EXTEND); + EVT VT = N->getValueType(0); + + // and/or/xor + SDValue N0 = N->getOperand(0); + if (!(N0.getOpcode() == ISD::AND || N0.getOpcode() == ISD::OR || + N0.getOpcode() == ISD::XOR) || + N0.getOperand(1).getOpcode() != ISD::Constant || + (LegalOperations && !TLI.isOperationLegal(N0.getOpcode(), VT))) + return SDValue(); + + // shl/shr + SDValue N1 = N0->getOperand(0); + if (!(N1.getOpcode() == ISD::SHL || N1.getOpcode() == ISD::SRL) || + N1.getOperand(1).getOpcode() != ISD::Constant || + (LegalOperations && !TLI.isOperationLegal(N1.getOpcode(), VT))) + return SDValue(); + + // load + if (!isa<LoadSDNode>(N1.getOperand(0))) + return SDValue(); + LoadSDNode *Load = cast<LoadSDNode>(N1.getOperand(0)); + EVT MemVT = Load->getMemoryVT(); + if (!TLI.isLoadExtLegal(ISD::ZEXTLOAD, VT, MemVT) || + Load->getExtensionType() == ISD::SEXTLOAD || Load->isIndexed()) + return SDValue(); + + + // If the shift op is SHL, the logic op must be AND, otherwise the result + // will be wrong. + if (N1.getOpcode() == ISD::SHL && N0.getOpcode() != ISD::AND) + return SDValue(); + + if (!N0.hasOneUse() || !N1.hasOneUse()) + return SDValue(); + + SmallVector<SDNode*, 4> SetCCs; + if (!ExtendUsesToFormExtLoad(VT, N1.getNode(), N1.getOperand(0), + ISD::ZERO_EXTEND, SetCCs, TLI)) + return SDValue(); + + // Actually do the transformation. 
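+  // e.g. (zext i32 (and (srl (load i16 p), 8), 0xff)) becomes
+  //      (and (srl (zextload i32 p), 8), 0xff), i.e. the shift and mask are
+  //      redone in the wide type after the load is zero-extended.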
+ SDValue ExtLoad = DAG.getExtLoad(ISD::ZEXTLOAD, SDLoc(Load), VT, + Load->getChain(), Load->getBasePtr(), + Load->getMemoryVT(), Load->getMemOperand()); + + SDLoc DL1(N1); + SDValue Shift = DAG.getNode(N1.getOpcode(), DL1, VT, ExtLoad, + N1.getOperand(1)); + + APInt Mask = cast<ConstantSDNode>(N0.getOperand(1))->getAPIntValue(); + Mask = Mask.zext(VT.getSizeInBits()); + SDLoc DL0(N0); + SDValue And = DAG.getNode(N0.getOpcode(), DL0, VT, Shift, + DAG.getConstant(Mask, DL0, VT)); + + ExtendSetCCUses(SetCCs, N1.getOperand(0), ExtLoad, ISD::ZERO_EXTEND); + CombineTo(N, And); + if (SDValue(Load, 0).hasOneUse()) { + DAG.ReplaceAllUsesOfValueWith(SDValue(Load, 1), ExtLoad.getValue(1)); + } else { + SDValue Trunc = DAG.getNode(ISD::TRUNCATE, SDLoc(Load), + Load->getValueType(0), ExtLoad); + CombineTo(Load, Trunc, ExtLoad.getValue(1)); + } + return SDValue(N,0); // Return N so it doesn't get rechecked! +} + /// If we're narrowing or widening the result of a vector select and the final /// size is the same size as a setcc (compare) feeding the select, then try to /// apply the cast operation to the select's operands because matching vector @@ -7446,6 +8099,106 @@ SDValue DAGCombiner::matchVSelectOpSizesWithSetCC(SDNode *Cast) { return DAG.getNode(ISD::VSELECT, DL, VT, SetCC, CastA, CastB); } +// fold ([s|z]ext ([s|z]extload x)) -> ([s|z]ext (truncate ([s|z]extload x))) +// fold ([s|z]ext ( extload x)) -> ([s|z]ext (truncate ([s|z]extload x))) +static SDValue tryToFoldExtOfExtload(SelectionDAG &DAG, DAGCombiner &Combiner, + const TargetLowering &TLI, EVT VT, + bool LegalOperations, SDNode *N, + SDValue N0, ISD::LoadExtType ExtLoadType) { + SDNode *N0Node = N0.getNode(); + bool isAExtLoad = (ExtLoadType == ISD::SEXTLOAD) ? ISD::isSEXTLoad(N0Node) + : ISD::isZEXTLoad(N0Node); + if ((!isAExtLoad && !ISD::isEXTLoad(N0Node)) || + !ISD::isUNINDEXEDLoad(N0Node) || !N0.hasOneUse()) + return {}; + + LoadSDNode *LN0 = cast<LoadSDNode>(N0); + EVT MemVT = LN0->getMemoryVT(); + if ((LegalOperations || LN0->isVolatile()) && + !TLI.isLoadExtLegal(ExtLoadType, VT, MemVT)) + return {}; + + SDValue ExtLoad = + DAG.getExtLoad(ExtLoadType, SDLoc(LN0), VT, LN0->getChain(), + LN0->getBasePtr(), MemVT, LN0->getMemOperand()); + Combiner.CombineTo(N, ExtLoad); + DAG.ReplaceAllUsesOfValueWith(SDValue(LN0, 1), ExtLoad.getValue(1)); + return SDValue(N, 0); // Return N so it doesn't get rechecked! +} + +// fold ([s|z]ext (load x)) -> ([s|z]ext (truncate ([s|z]extload x))) +// Only generate vector extloads when 1) they're legal, and 2) they are +// deemed desirable by the target. 
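+// (For (2), TLI.isVectorLoadExtDesirable is the hook consulted below.)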
+static SDValue tryToFoldExtOfLoad(SelectionDAG &DAG, DAGCombiner &Combiner, + const TargetLowering &TLI, EVT VT, + bool LegalOperations, SDNode *N, SDValue N0, + ISD::LoadExtType ExtLoadType, + ISD::NodeType ExtOpc) { + if (!ISD::isNON_EXTLoad(N0.getNode()) || + !ISD::isUNINDEXEDLoad(N0.getNode()) || + ((LegalOperations || VT.isVector() || + cast<LoadSDNode>(N0)->isVolatile()) && + !TLI.isLoadExtLegal(ExtLoadType, VT, N0.getValueType()))) + return {}; + + bool DoXform = true; + SmallVector<SDNode *, 4> SetCCs; + if (!N0.hasOneUse()) + DoXform = ExtendUsesToFormExtLoad(VT, N, N0, ExtOpc, SetCCs, TLI); + if (VT.isVector()) + DoXform &= TLI.isVectorLoadExtDesirable(SDValue(N, 0)); + if (!DoXform) + return {}; + + LoadSDNode *LN0 = cast<LoadSDNode>(N0); + SDValue ExtLoad = DAG.getExtLoad(ExtLoadType, SDLoc(LN0), VT, LN0->getChain(), + LN0->getBasePtr(), N0.getValueType(), + LN0->getMemOperand()); + Combiner.ExtendSetCCUses(SetCCs, N0, ExtLoad, ExtOpc); + // If the load value is used only by N, replace it via CombineTo N. + bool NoReplaceTrunc = SDValue(LN0, 0).hasOneUse(); + Combiner.CombineTo(N, ExtLoad); + if (NoReplaceTrunc) { + DAG.ReplaceAllUsesOfValueWith(SDValue(LN0, 1), ExtLoad.getValue(1)); + } else { + SDValue Trunc = + DAG.getNode(ISD::TRUNCATE, SDLoc(N0), N0.getValueType(), ExtLoad); + Combiner.CombineTo(LN0, Trunc, ExtLoad.getValue(1)); + } + return SDValue(N, 0); // Return N so it doesn't get rechecked! +} + +static SDValue foldExtendedSignBitTest(SDNode *N, SelectionDAG &DAG, + bool LegalOperations) { + assert((N->getOpcode() == ISD::SIGN_EXTEND || + N->getOpcode() == ISD::ZERO_EXTEND) && "Expected sext or zext"); + + SDValue SetCC = N->getOperand(0); + if (LegalOperations || SetCC.getOpcode() != ISD::SETCC || + !SetCC.hasOneUse() || SetCC.getValueType() != MVT::i1) + return SDValue(); + + SDValue X = SetCC.getOperand(0); + SDValue Ones = SetCC.getOperand(1); + ISD::CondCode CC = cast<CondCodeSDNode>(SetCC.getOperand(2))->get(); + EVT VT = N->getValueType(0); + EVT XVT = X.getValueType(); + // setge X, C is canonicalized to setgt, so we do not need to match that + // pattern. The setlt sibling is folded in SimplifySelectCC() because it does + // not require the 'not' op. + if (CC == ISD::SETGT && isAllOnesConstant(Ones) && VT == XVT) { + // Invert and smear/shift the sign bit: + // sext i1 (setgt iN X, -1) --> sra (not X), (N - 1) + // zext i1 (setgt iN X, -1) --> srl (not X), (N - 1) + SDLoc DL(N); + SDValue NotX = DAG.getNOT(DL, X, VT); + SDValue ShiftAmount = DAG.getConstant(VT.getSizeInBits() - 1, DL, VT); + auto ShiftOpcode = N->getOpcode() == ISD::SIGN_EXTEND ? ISD::SRA : ISD::SRL; + return DAG.getNode(ShiftOpcode, DL, VT, NotX, ShiftAmount); + } + return SDValue(); +} + SDValue DAGCombiner::visitSIGN_EXTEND(SDNode *N) { SDValue N0 = N->getOperand(0); EVT VT = N->getValueType(0); @@ -7510,62 +8263,21 @@ SDValue DAGCombiner::visitSIGN_EXTEND(SDNode *N) { } } - // fold (sext (load x)) -> (sext (truncate (sextload x))) - // Only generate vector extloads when 1) they're legal, and 2) they are - // deemed desirable by the target. 
- if (ISD::isNON_EXTLoad(N0.getNode()) && ISD::isUNINDEXEDLoad(N0.getNode()) && - ((!LegalOperations && !VT.isVector() && - !cast<LoadSDNode>(N0)->isVolatile()) || - TLI.isLoadExtLegal(ISD::SEXTLOAD, VT, N0.getValueType()))) { - bool DoXform = true; - SmallVector<SDNode*, 4> SetCCs; - if (!N0.hasOneUse()) - DoXform = ExtendUsesToFormExtLoad(N, N0, ISD::SIGN_EXTEND, SetCCs, TLI); - if (VT.isVector()) - DoXform &= TLI.isVectorLoadExtDesirable(SDValue(N, 0)); - if (DoXform) { - LoadSDNode *LN0 = cast<LoadSDNode>(N0); - SDValue ExtLoad = DAG.getExtLoad(ISD::SEXTLOAD, DL, VT, LN0->getChain(), - LN0->getBasePtr(), N0.getValueType(), - LN0->getMemOperand()); - SDValue Trunc = DAG.getNode(ISD::TRUNCATE, SDLoc(N0), - N0.getValueType(), ExtLoad); - ExtendSetCCUses(SetCCs, Trunc, ExtLoad, DL, ISD::SIGN_EXTEND); - // If the load value is used only by N, replace it via CombineTo N. - bool NoReplaceTrunc = SDValue(LN0, 0).hasOneUse(); - CombineTo(N, ExtLoad); - if (NoReplaceTrunc) - DAG.ReplaceAllUsesOfValueWith(SDValue(LN0, 1), ExtLoad.getValue(1)); - else - CombineTo(LN0, Trunc, ExtLoad.getValue(1)); - return SDValue(N, 0); - } - } + // Try to simplify (sext (load x)). + if (SDValue foldedExt = + tryToFoldExtOfLoad(DAG, *this, TLI, VT, LegalOperations, N, N0, + ISD::SEXTLOAD, ISD::SIGN_EXTEND)) + return foldedExt; // fold (sext (load x)) to multiple smaller sextloads. // Only on illegal but splittable vectors. if (SDValue ExtLoad = CombineExtLoad(N)) return ExtLoad; - // fold (sext (sextload x)) -> (sext (truncate (sextload x))) - // fold (sext ( extload x)) -> (sext (truncate (sextload x))) - if ((ISD::isSEXTLoad(N0.getNode()) || ISD::isEXTLoad(N0.getNode())) && - ISD::isUNINDEXEDLoad(N0.getNode()) && N0.hasOneUse()) { - LoadSDNode *LN0 = cast<LoadSDNode>(N0); - EVT MemVT = LN0->getMemoryVT(); - if ((!LegalOperations && !LN0->isVolatile()) || - TLI.isLoadExtLegal(ISD::SEXTLOAD, VT, MemVT)) { - SDValue ExtLoad = DAG.getExtLoad(ISD::SEXTLOAD, DL, VT, LN0->getChain(), - LN0->getBasePtr(), MemVT, - LN0->getMemOperand()); - CombineTo(N, ExtLoad); - CombineTo(N0.getNode(), - DAG.getNode(ISD::TRUNCATE, SDLoc(N0), - N0.getValueType(), ExtLoad), - ExtLoad.getValue(1)); - return SDValue(N, 0); // Return N so it doesn't get rechecked! - } - } + // Try to simplify (sext (sextload x)). 
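+  // (tryToFoldExtOfExtload folds both extensions into one wider SEXTLOAD of
+  // the original memory type.)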
+ if (SDValue foldedExt = tryToFoldExtOfExtload( + DAG, *this, TLI, VT, LegalOperations, N, N0, ISD::SEXTLOAD)) + return foldedExt; // fold (sext (and/or/xor (load x), cst)) -> // (and/or/xor (sextload x), (sext cst)) @@ -7573,30 +8285,26 @@ SDValue DAGCombiner::visitSIGN_EXTEND(SDNode *N) { N0.getOpcode() == ISD::XOR) && isa<LoadSDNode>(N0.getOperand(0)) && N0.getOperand(1).getOpcode() == ISD::Constant && - TLI.isLoadExtLegal(ISD::SEXTLOAD, VT, N0.getValueType()) && (!LegalOperations && TLI.isOperationLegal(N0.getOpcode(), VT))) { - LoadSDNode *LN0 = cast<LoadSDNode>(N0.getOperand(0)); - if (LN0->getExtensionType() != ISD::ZEXTLOAD && LN0->isUnindexed()) { - bool DoXform = true; + LoadSDNode *LN00 = cast<LoadSDNode>(N0.getOperand(0)); + EVT MemVT = LN00->getMemoryVT(); + if (TLI.isLoadExtLegal(ISD::SEXTLOAD, VT, MemVT) && + LN00->getExtensionType() != ISD::ZEXTLOAD && LN00->isUnindexed()) { SmallVector<SDNode*, 4> SetCCs; - if (!N0.hasOneUse()) - DoXform = ExtendUsesToFormExtLoad(N, N0.getOperand(0), ISD::SIGN_EXTEND, - SetCCs, TLI); + bool DoXform = ExtendUsesToFormExtLoad(VT, N0.getNode(), N0.getOperand(0), + ISD::SIGN_EXTEND, SetCCs, TLI); if (DoXform) { - SDValue ExtLoad = DAG.getExtLoad(ISD::SEXTLOAD, SDLoc(LN0), VT, - LN0->getChain(), LN0->getBasePtr(), - LN0->getMemoryVT(), - LN0->getMemOperand()); + SDValue ExtLoad = DAG.getExtLoad(ISD::SEXTLOAD, SDLoc(LN00), VT, + LN00->getChain(), LN00->getBasePtr(), + LN00->getMemoryVT(), + LN00->getMemOperand()); APInt Mask = cast<ConstantSDNode>(N0.getOperand(1))->getAPIntValue(); Mask = Mask.sext(VT.getSizeInBits()); SDValue And = DAG.getNode(N0.getOpcode(), DL, VT, ExtLoad, DAG.getConstant(Mask, DL, VT)); - SDValue Trunc = DAG.getNode(ISD::TRUNCATE, - SDLoc(N0.getOperand(0)), - N0.getOperand(0).getValueType(), ExtLoad); - ExtendSetCCUses(SetCCs, Trunc, ExtLoad, DL, ISD::SIGN_EXTEND); + ExtendSetCCUses(SetCCs, N0.getOperand(0), ExtLoad, ISD::SIGN_EXTEND); bool NoReplaceTruncAnd = !N0.hasOneUse(); - bool NoReplaceTrunc = SDValue(LN0, 0).hasOneUse(); + bool NoReplaceTrunc = SDValue(LN00, 0).hasOneUse(); CombineTo(N, And); // If N0 has multiple uses, change other uses as well. if (NoReplaceTruncAnd) { @@ -7604,15 +8312,21 @@ SDValue DAGCombiner::visitSIGN_EXTEND(SDNode *N) { DAG.getNode(ISD::TRUNCATE, DL, N0.getValueType(), And); CombineTo(N0.getNode(), TruncAnd); } - if (NoReplaceTrunc) - DAG.ReplaceAllUsesOfValueWith(SDValue(LN0, 1), ExtLoad.getValue(1)); - else - CombineTo(LN0, Trunc, ExtLoad.getValue(1)); + if (NoReplaceTrunc) { + DAG.ReplaceAllUsesOfValueWith(SDValue(LN00, 1), ExtLoad.getValue(1)); + } else { + SDValue Trunc = DAG.getNode(ISD::TRUNCATE, SDLoc(LN00), + LN00->getValueType(0), ExtLoad); + CombineTo(LN00, Trunc, ExtLoad.getValue(1)); + } return SDValue(N,0); // Return N so it doesn't get rechecked! } } } + if (SDValue V = foldExtendedSignBitTest(N, DAG, LegalOperations)) + return V; + if (N0.getOpcode() == ISD::SETCC) { SDValue N00 = N0.getOperand(0); SDValue N01 = N0.getOperand(1); @@ -7659,8 +8373,9 @@ SDValue DAGCombiner::visitSIGN_EXTEND(SDNode *N) { // If the type of the setcc is larger (say, i8) then the value of the high // bit depends on getBooleanContents(), so ask TLI for a real "true" value // of the appropriate width. - SDValue ExtTrueVal = (SetCCWidth == 1) ? DAG.getAllOnesConstant(DL, VT) - : TLI.getConstTrueVal(DAG, VT, DL); + SDValue ExtTrueVal = (SetCCWidth == 1) + ? 
DAG.getAllOnesConstant(DL, VT) + : DAG.getBoolConstant(true, DL, VT, N00VT); SDValue Zero = DAG.getConstant(0, DL, VT); if (SDValue SCC = SimplifySelectCC(DL, N00, N01, ExtTrueVal, Zero, CC, true)) @@ -7777,13 +8492,16 @@ SDValue DAGCombiner::visitZERO_EXTEND(SDNode *N) { // Try to mask before the extension to avoid having to generate a larger mask, // possibly over several sub-vectors. - if (SrcVT.bitsLT(VT)) { + if (SrcVT.bitsLT(VT) && VT.isVector()) { if (!LegalOperations || (TLI.isOperationLegal(ISD::AND, SrcVT) && TLI.isOperationLegal(ISD::ZERO_EXTEND, VT))) { SDValue Op = N0.getOperand(0); Op = DAG.getZeroExtendInReg(Op, SDLoc(N), MinVT.getScalarType()); AddToWorklist(Op.getNode()); - return DAG.getZExtOrTrunc(Op, SDLoc(N), VT); + SDValue ZExtOrTrunc = DAG.getZExtOrTrunc(Op, SDLoc(N), VT); + // Transfer the debug info; the new node is equivalent to N0. + DAG.transferDbgValues(N0, ZExtOrTrunc); + return ZExtOrTrunc; } } @@ -7815,39 +8533,11 @@ SDValue DAGCombiner::visitZERO_EXTEND(SDNode *N) { X, DAG.getConstant(Mask, DL, VT)); } - // fold (zext (load x)) -> (zext (truncate (zextload x))) - // Only generate vector extloads when 1) they're legal, and 2) they are - // deemed desirable by the target. - if (ISD::isNON_EXTLoad(N0.getNode()) && ISD::isUNINDEXEDLoad(N0.getNode()) && - ((!LegalOperations && !VT.isVector() && - !cast<LoadSDNode>(N0)->isVolatile()) || - TLI.isLoadExtLegal(ISD::ZEXTLOAD, VT, N0.getValueType()))) { - bool DoXform = true; - SmallVector<SDNode*, 4> SetCCs; - if (!N0.hasOneUse()) - DoXform = ExtendUsesToFormExtLoad(N, N0, ISD::ZERO_EXTEND, SetCCs, TLI); - if (VT.isVector()) - DoXform &= TLI.isVectorLoadExtDesirable(SDValue(N, 0)); - if (DoXform) { - LoadSDNode *LN0 = cast<LoadSDNode>(N0); - SDValue ExtLoad = DAG.getExtLoad(ISD::ZEXTLOAD, SDLoc(N), VT, - LN0->getChain(), - LN0->getBasePtr(), N0.getValueType(), - LN0->getMemOperand()); - - SDValue Trunc = DAG.getNode(ISD::TRUNCATE, SDLoc(N0), - N0.getValueType(), ExtLoad); - ExtendSetCCUses(SetCCs, Trunc, ExtLoad, SDLoc(N), ISD::ZERO_EXTEND); - // If the load value is used only by N, replace it via CombineTo N. - bool NoReplaceTrunc = SDValue(LN0, 0).hasOneUse(); - CombineTo(N, ExtLoad); - if (NoReplaceTrunc) - DAG.ReplaceAllUsesOfValueWith(SDValue(LN0, 1), ExtLoad.getValue(1)); - else - CombineTo(LN0, Trunc, ExtLoad.getValue(1)); - return SDValue(N, 0); // Return N so it doesn't get rechecked! - } - } + // Try to simplify (zext (load x)). + if (SDValue foldedExt = + tryToFoldExtOfLoad(DAG, *this, TLI, VT, LegalOperations, N, N0, + ISD::ZEXTLOAD, ISD::ZERO_EXTEND)) + return foldedExt; // fold (zext (load x)) to multiple smaller zextloads. // Only on illegal but splittable vectors. 
@@ -7862,10 +8552,11 @@ SDValue DAGCombiner::visitZERO_EXTEND(SDNode *N) { N0.getOpcode() == ISD::XOR) && isa<LoadSDNode>(N0.getOperand(0)) && N0.getOperand(1).getOpcode() == ISD::Constant && - TLI.isLoadExtLegal(ISD::ZEXTLOAD, VT, N0.getValueType()) && (!LegalOperations && TLI.isOperationLegal(N0.getOpcode(), VT))) { - LoadSDNode *LN0 = cast<LoadSDNode>(N0.getOperand(0)); - if (LN0->getExtensionType() != ISD::SEXTLOAD && LN0->isUnindexed()) { + LoadSDNode *LN00 = cast<LoadSDNode>(N0.getOperand(0)); + EVT MemVT = LN00->getMemoryVT(); + if (TLI.isLoadExtLegal(ISD::ZEXTLOAD, VT, MemVT) && + LN00->getExtensionType() != ISD::SEXTLOAD && LN00->isUnindexed()) { bool DoXform = true; SmallVector<SDNode*, 4> SetCCs; if (!N0.hasOneUse()) { @@ -7873,29 +8564,26 @@ SDValue DAGCombiner::visitZERO_EXTEND(SDNode *N) { auto *AndC = cast<ConstantSDNode>(N0.getOperand(1)); EVT LoadResultTy = AndC->getValueType(0); EVT ExtVT; - if (isAndLoadExtLoad(AndC, LN0, LoadResultTy, ExtVT)) + if (isAndLoadExtLoad(AndC, LN00, LoadResultTy, ExtVT)) DoXform = false; } - if (DoXform) - DoXform = ExtendUsesToFormExtLoad(N, N0.getOperand(0), - ISD::ZERO_EXTEND, SetCCs, TLI); } + if (DoXform) + DoXform = ExtendUsesToFormExtLoad(VT, N0.getNode(), N0.getOperand(0), + ISD::ZERO_EXTEND, SetCCs, TLI); if (DoXform) { - SDValue ExtLoad = DAG.getExtLoad(ISD::ZEXTLOAD, SDLoc(LN0), VT, - LN0->getChain(), LN0->getBasePtr(), - LN0->getMemoryVT(), - LN0->getMemOperand()); + SDValue ExtLoad = DAG.getExtLoad(ISD::ZEXTLOAD, SDLoc(LN00), VT, + LN00->getChain(), LN00->getBasePtr(), + LN00->getMemoryVT(), + LN00->getMemOperand()); APInt Mask = cast<ConstantSDNode>(N0.getOperand(1))->getAPIntValue(); Mask = Mask.zext(VT.getSizeInBits()); SDLoc DL(N); SDValue And = DAG.getNode(N0.getOpcode(), DL, VT, ExtLoad, DAG.getConstant(Mask, DL, VT)); - SDValue Trunc = DAG.getNode(ISD::TRUNCATE, - SDLoc(N0.getOperand(0)), - N0.getOperand(0).getValueType(), ExtLoad); - ExtendSetCCUses(SetCCs, Trunc, ExtLoad, DL, ISD::ZERO_EXTEND); + ExtendSetCCUses(SetCCs, N0.getOperand(0), ExtLoad, ISD::ZERO_EXTEND); bool NoReplaceTruncAnd = !N0.hasOneUse(); - bool NoReplaceTrunc = SDValue(LN0, 0).hasOneUse(); + bool NoReplaceTrunc = SDValue(LN00, 0).hasOneUse(); CombineTo(N, And); // If N0 has multiple uses, change other uses as well. if (NoReplaceTruncAnd) { @@ -7903,35 +8591,30 @@ SDValue DAGCombiner::visitZERO_EXTEND(SDNode *N) { DAG.getNode(ISD::TRUNCATE, DL, N0.getValueType(), And); CombineTo(N0.getNode(), TruncAnd); } - if (NoReplaceTrunc) - DAG.ReplaceAllUsesOfValueWith(SDValue(LN0, 1), ExtLoad.getValue(1)); - else - CombineTo(LN0, Trunc, ExtLoad.getValue(1)); + if (NoReplaceTrunc) { + DAG.ReplaceAllUsesOfValueWith(SDValue(LN00, 1), ExtLoad.getValue(1)); + } else { + SDValue Trunc = DAG.getNode(ISD::TRUNCATE, SDLoc(LN00), + LN00->getValueType(0), ExtLoad); + CombineTo(LN00, Trunc, ExtLoad.getValue(1)); + } return SDValue(N,0); // Return N so it doesn't get rechecked! 
} } } - // fold (zext (zextload x)) -> (zext (truncate (zextload x))) - // fold (zext ( extload x)) -> (zext (truncate (zextload x))) - if ((ISD::isZEXTLoad(N0.getNode()) || ISD::isEXTLoad(N0.getNode())) && - ISD::isUNINDEXEDLoad(N0.getNode()) && N0.hasOneUse()) { - LoadSDNode *LN0 = cast<LoadSDNode>(N0); - EVT MemVT = LN0->getMemoryVT(); - if ((!LegalOperations && !LN0->isVolatile()) || - TLI.isLoadExtLegal(ISD::ZEXTLOAD, VT, MemVT)) { - SDValue ExtLoad = DAG.getExtLoad(ISD::ZEXTLOAD, SDLoc(N), VT, - LN0->getChain(), - LN0->getBasePtr(), MemVT, - LN0->getMemOperand()); - CombineTo(N, ExtLoad); - CombineTo(N0.getNode(), - DAG.getNode(ISD::TRUNCATE, SDLoc(N0), N0.getValueType(), - ExtLoad), - ExtLoad.getValue(1)); - return SDValue(N, 0); // Return N so it doesn't get rechecked! - } - } + // fold (zext (and/or/xor (shl/shr (load x), cst), cst)) -> + // (and/or/xor (shl/shr (zextload x), (zext cst)), (zext cst)) + if (SDValue ZExtLoad = CombineZExtLogicopShiftLoad(N)) + return ZExtLoad; + + // Try to simplify (zext (zextload x)). + if (SDValue foldedExt = tryToFoldExtOfExtload( + DAG, *this, TLI, VT, LegalOperations, N, N0, ISD::ZEXTLOAD)) + return foldedExt; + + if (SDValue V = foldExtendedSignBitTest(N, DAG, LegalOperations)) + return V; if (N0.getOpcode() == ISD::SETCC) { // Only do this before legalize for now. @@ -8069,24 +8752,25 @@ SDValue DAGCombiner::visitANY_EXTEND(SDNode *N) { bool DoXform = true; SmallVector<SDNode*, 4> SetCCs; if (!N0.hasOneUse()) - DoXform = ExtendUsesToFormExtLoad(N, N0, ISD::ANY_EXTEND, SetCCs, TLI); + DoXform = ExtendUsesToFormExtLoad(VT, N, N0, ISD::ANY_EXTEND, SetCCs, + TLI); if (DoXform) { LoadSDNode *LN0 = cast<LoadSDNode>(N0); SDValue ExtLoad = DAG.getExtLoad(ISD::EXTLOAD, SDLoc(N), VT, LN0->getChain(), LN0->getBasePtr(), N0.getValueType(), LN0->getMemOperand()); - SDValue Trunc = DAG.getNode(ISD::TRUNCATE, SDLoc(N0), - N0.getValueType(), ExtLoad); - ExtendSetCCUses(SetCCs, Trunc, ExtLoad, SDLoc(N), - ISD::ANY_EXTEND); + ExtendSetCCUses(SetCCs, N0, ExtLoad, ISD::ANY_EXTEND); // If the load value is used only by N, replace it via CombineTo N. bool NoReplaceTrunc = N0.hasOneUse(); CombineTo(N, ExtLoad); - if (NoReplaceTrunc) + if (NoReplaceTrunc) { DAG.ReplaceAllUsesOfValueWith(SDValue(LN0, 1), ExtLoad.getValue(1)); - else + } else { + SDValue Trunc = DAG.getNode(ISD::TRUNCATE, SDLoc(N0), + N0.getValueType(), ExtLoad); CombineTo(LN0, Trunc, ExtLoad.getValue(1)); + } return SDValue(N, 0); // Return N so it doesn't get rechecked! 
} } @@ -8094,9 +8778,8 @@ SDValue DAGCombiner::visitANY_EXTEND(SDNode *N) { // fold (aext (zextload x)) -> (aext (truncate (zextload x))) // fold (aext (sextload x)) -> (aext (truncate (sextload x))) // fold (aext ( extload x)) -> (aext (truncate (extload x))) - if (N0.getOpcode() == ISD::LOAD && - !ISD::isNON_EXTLoad(N0.getNode()) && ISD::isUNINDEXEDLoad(N0.getNode()) && - N0.hasOneUse()) { + if (N0.getOpcode() == ISD::LOAD && !ISD::isNON_EXTLoad(N0.getNode()) && + ISD::isUNINDEXEDLoad(N0.getNode()) && N0.hasOneUse()) { LoadSDNode *LN0 = cast<LoadSDNode>(N0); ISD::LoadExtType ExtType = LN0->getExtensionType(); EVT MemVT = LN0->getMemoryVT(); @@ -8105,10 +8788,7 @@ SDValue DAGCombiner::visitANY_EXTEND(SDNode *N) { VT, LN0->getChain(), LN0->getBasePtr(), MemVT, LN0->getMemOperand()); CombineTo(N, ExtLoad); - CombineTo(N0.getNode(), - DAG.getNode(ISD::TRUNCATE, SDLoc(N0), - N0.getValueType(), ExtLoad), - ExtLoad.getValue(1)); + DAG.ReplaceAllUsesOfValueWith(SDValue(LN0, 1), ExtLoad.getValue(1)); return SDValue(N, 0); // Return N so it doesn't get rechecked! } } @@ -8248,8 +8928,9 @@ SDValue DAGCombiner::ReduceLoadWidth(SDNode *N) { unsigned ShAmt = 0; if (N0.getOpcode() == ISD::SRL && N0.hasOneUse()) { - if (ConstantSDNode *N01 = dyn_cast<ConstantSDNode>(N0.getOperand(1))) { - ShAmt = N01->getZExtValue(); + SDValue SRL = N0; + if (auto *ConstShift = dyn_cast<ConstantSDNode>(SRL.getOperand(1))) { + ShAmt = ConstShift->getZExtValue(); unsigned EVTBits = ExtVT.getSizeInBits(); // Is the shift amount a multiple of size of VT? if ((ShAmt & (EVTBits-1)) == 0) { @@ -8262,17 +8943,36 @@ SDValue DAGCombiner::ReduceLoadWidth(SDNode *N) { // At this point, we must have a load or else we can't do the transform. if (!isa<LoadSDNode>(N0)) return SDValue(); + auto *LN0 = cast<LoadSDNode>(N0); + // Because a SRL must be assumed to *need* to zero-extend the high bits // (as opposed to anyext the high bits), we can't combine the zextload // lowering of SRL and an sextload. - if (cast<LoadSDNode>(N0)->getExtensionType() == ISD::SEXTLOAD) + if (LN0->getExtensionType() == ISD::SEXTLOAD) return SDValue(); // If the shift amount is larger than the input type then we're not // accessing any of the loaded bytes. If the load was a zextload/extload // then the result of the shift+trunc is zero/undef (handled elsewhere). - if (ShAmt >= cast<LoadSDNode>(N0)->getMemoryVT().getSizeInBits()) + if (ShAmt >= LN0->getMemoryVT().getSizeInBits()) return SDValue(); + + // If the SRL is only used by a masking AND, we may be able to adjust + // the ExtVT to make the AND redundant. + SDNode *Mask = *(SRL->use_begin()); + if (Mask->getOpcode() == ISD::AND && + isa<ConstantSDNode>(Mask->getOperand(1))) { + const APInt &ShiftMask = + cast<ConstantSDNode>(Mask->getOperand(1))->getAPIntValue(); + if (ShiftMask.isMask()) { + EVT MaskedVT = EVT::getIntegerVT(*DAG.getContext(), + ShiftMask.countTrailingOnes()); + // If the mask is smaller, recompute the type. 
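+        // e.g. an ExtVT of i16 combined with a shift-mask of 0xff means only
+        // one loaded byte is demanded, so ExtVT can shrink to i8.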
+        if ((ExtVT.getSizeInBits() > MaskedVT.getSizeInBits()) &&
+            TLI.isLoadExtLegal(ExtType, N0.getValueType(), MaskedVT))
+          ExtVT = MaskedVT;
+      }
+    }
    }
  }

@@ -8292,7 +8992,7 @@ SDValue DAGCombiner::ReduceLoadWidth(SDNode *N) {
    return SDValue();

  LoadSDNode *LN0 = cast<LoadSDNode>(N0);
-  if (!isLegalNarrowLoad(LN0, ExtType, ExtVT, ShAmt))
+  if (!isLegalNarrowLdSt(LN0, ExtType, ExtVT, ShAmt))
    return SDValue();

  // For big endian targets, we need to adjust the offset to the pointer to
@@ -8388,7 +9088,7 @@ SDValue DAGCombiner::visitSIGN_EXTEND_INREG(SDNode *N) {
      return DAG.getNode(ISD::SIGN_EXTEND, SDLoc(N), VT, N00, N1);
  }

-  // fold (sext_in_reg (*_extend_vector_inreg x)) -> (sext_vector_in_reg x)
+  // fold (sext_in_reg (*_extend_vector_inreg x)) -> (sext_vector_inreg x)
  if ((N0.getOpcode() == ISD::ANY_EXTEND_VECTOR_INREG ||
       N0.getOpcode() == ISD::SIGN_EXTEND_VECTOR_INREG ||
       N0.getOpcode() == ISD::ZERO_EXTEND_VECTOR_INREG) &&
@@ -8762,6 +9462,22 @@ SDValue DAGCombiner::visitTRUNCATE(SDNode *N) {
      return DAG.getNode(N0.getOpcode(), SL, VTs, X, Y, N0.getOperand(2));
  }

+  // fold (truncate (extract_subvector(ext x))) ->
+  //      (extract_subvector x)
+  // TODO: This can be generalized to cover cases where the truncate and extract
+  // do not fully cancel each other out.
+  if (!LegalTypes && N0.getOpcode() == ISD::EXTRACT_SUBVECTOR) {
+    SDValue N00 = N0.getOperand(0);
+    if (N00.getOpcode() == ISD::SIGN_EXTEND ||
+        N00.getOpcode() == ISD::ZERO_EXTEND ||
+        N00.getOpcode() == ISD::ANY_EXTEND) {
+      if (N00.getOperand(0)->getValueType(0).getVectorElementType() ==
+          VT.getVectorElementType())
+        return DAG.getNode(ISD::EXTRACT_SUBVECTOR, SDLoc(N0->getOperand(0)), VT,
+                           N00.getOperand(0), N0.getOperand(1));
+    }
+  }
+
  if (SDValue NewVSel = matchVSelectOpSizesWithSetCC(N))
    return NewVSel;

@@ -8882,17 +9598,17 @@ SDValue DAGCombiner::visitBITCAST(SDNode *N) {
  }

  // If the input is a constant, let getNode fold it.
-  if (isa<ConstantSDNode>(N0) || isa<ConstantFPSDNode>(N0)) {
-    // If we can't allow illegal operations, we need to check that this is just
-    // a fp -> int or int -> conversion and that the resulting operation will
-    // be legal.
-    if (!LegalOperations ||
-        (isa<ConstantSDNode>(N0) && VT.isFloatingPoint() && !VT.isVector() &&
-         TLI.isOperationLegal(ISD::ConstantFP, VT)) ||
-        (isa<ConstantFPSDNode>(N0) && VT.isInteger() && !VT.isVector() &&
-         TLI.isOperationLegal(ISD::Constant, VT)))
-      return DAG.getBitcast(VT, N0);
-  }
+  // We always need to check that this is just an fp -> int or int -> fp
+  // conversion; otherwise we will get back N, which will confuse the caller
+  // into thinking we used CombineTo. This can block target combines from
+  // running. If we can't allow illegal operations, we need to ensure the
+  // resulting operation will be legal.
+  // TODO: Maybe we should check that the return value isn't N explicitly?
+ if ((isa<ConstantSDNode>(N0) && VT.isFloatingPoint() && !VT.isVector() && + (!LegalOperations || TLI.isOperationLegal(ISD::ConstantFP, VT))) || + (isa<ConstantFPSDNode>(N0) && VT.isInteger() && !VT.isVector() && + (!LegalOperations || TLI.isOperationLegal(ISD::Constant, VT)))) + return DAG.getBitcast(VT, N0); // (conv (conv x, t1), t2) -> (conv x, t2) if (N0.getOpcode() == ISD::BITCAST) @@ -9238,7 +9954,7 @@ ConstantFoldBITCASTofBUILD_VECTOR(SDNode *BV, EVT DstEltVT) { static bool isContractable(SDNode *N) { SDNodeFlags F = N->getFlags(); - return F.hasAllowContract() || F.hasUnsafeAlgebra(); + return F.hasAllowContract() || F.hasAllowReassociation(); } /// Try to perform FMA combining on a given FADD node. @@ -9262,8 +9978,10 @@ SDValue DAGCombiner::visitFADDForFMACombine(SDNode *N) { if (!HasFMAD && !HasFMA) return SDValue(); + SDNodeFlags Flags = N->getFlags(); + bool CanFuse = Options.UnsafeFPMath || isContractable(N); bool AllowFusionGlobally = (Options.AllowFPOpFusion == FPOpFusion::Fast || - Options.UnsafeFPMath || HasFMAD); + CanFuse || HasFMAD); // If the addition is not contractable, do not combine. if (!AllowFusionGlobally && !isContractable(N)) return SDValue(); @@ -9293,14 +10011,14 @@ SDValue DAGCombiner::visitFADDForFMACombine(SDNode *N) { // fold (fadd (fmul x, y), z) -> (fma x, y, z) if (isContractableFMUL(N0) && (Aggressive || N0->hasOneUse())) { return DAG.getNode(PreferredFusedOpcode, SL, VT, - N0.getOperand(0), N0.getOperand(1), N1); + N0.getOperand(0), N0.getOperand(1), N1, Flags); } // fold (fadd x, (fmul y, z)) -> (fma y, z, x) // Note: Commutes FADD operands. if (isContractableFMUL(N1) && (Aggressive || N1->hasOneUse())) { return DAG.getNode(PreferredFusedOpcode, SL, VT, - N1.getOperand(0), N1.getOperand(1), N0); + N1.getOperand(0), N1.getOperand(1), N0, Flags); } // Look through FP_EXTEND nodes to do more combining. @@ -9314,7 +10032,7 @@ SDValue DAGCombiner::visitFADDForFMACombine(SDNode *N) { DAG.getNode(ISD::FP_EXTEND, SL, VT, N00.getOperand(0)), DAG.getNode(ISD::FP_EXTEND, SL, VT, - N00.getOperand(1)), N1); + N00.getOperand(1)), N1, Flags); } } @@ -9328,16 +10046,14 @@ SDValue DAGCombiner::visitFADDForFMACombine(SDNode *N) { DAG.getNode(ISD::FP_EXTEND, SL, VT, N10.getOperand(0)), DAG.getNode(ISD::FP_EXTEND, SL, VT, - N10.getOperand(1)), N0); + N10.getOperand(1)), N0, Flags); } } // More folding opportunities when target permits. if (Aggressive) { // fold (fadd (fma x, y, (fmul u, v)), z) -> (fma x, y (fma u, v, z)) - // FIXME: The UnsafeAlgebra flag should be propagated to FMA/FMAD, but FMF - // are currently only supported on binary nodes. - if (Options.UnsafeFPMath && + if (CanFuse && N0.getOpcode() == PreferredFusedOpcode && N0.getOperand(2).getOpcode() == ISD::FMUL && N0->hasOneUse() && N0.getOperand(2)->hasOneUse()) { @@ -9346,13 +10062,11 @@ SDValue DAGCombiner::visitFADDForFMACombine(SDNode *N) { DAG.getNode(PreferredFusedOpcode, SL, VT, N0.getOperand(2).getOperand(0), N0.getOperand(2).getOperand(1), - N1)); + N1, Flags), Flags); } // fold (fadd x, (fma y, z, (fmul u, v)) -> (fma y, z (fma u, v, x)) - // FIXME: The UnsafeAlgebra flag should be propagated to FMA/FMAD, but FMF - // are currently only supported on binary nodes. 
- if (Options.UnsafeFPMath && + if (CanFuse && N1->getOpcode() == PreferredFusedOpcode && N1.getOperand(2).getOpcode() == ISD::FMUL && N1->hasOneUse() && N1.getOperand(2)->hasOneUse()) { @@ -9361,19 +10075,20 @@ SDValue DAGCombiner::visitFADDForFMACombine(SDNode *N) { DAG.getNode(PreferredFusedOpcode, SL, VT, N1.getOperand(2).getOperand(0), N1.getOperand(2).getOperand(1), - N0)); + N0, Flags), Flags); } // fold (fadd (fma x, y, (fpext (fmul u, v))), z) // -> (fma x, y, (fma (fpext u), (fpext v), z)) auto FoldFAddFMAFPExtFMul = [&] ( - SDValue X, SDValue Y, SDValue U, SDValue V, SDValue Z) { + SDValue X, SDValue Y, SDValue U, SDValue V, SDValue Z, + SDNodeFlags Flags) { return DAG.getNode(PreferredFusedOpcode, SL, VT, X, Y, DAG.getNode(PreferredFusedOpcode, SL, VT, DAG.getNode(ISD::FP_EXTEND, SL, VT, U), DAG.getNode(ISD::FP_EXTEND, SL, VT, V), - Z)); + Z, Flags), Flags); }; if (N0.getOpcode() == PreferredFusedOpcode) { SDValue N02 = N0.getOperand(2); @@ -9383,7 +10098,7 @@ SDValue DAGCombiner::visitFADDForFMACombine(SDNode *N) { TLI.isFPExtFoldable(PreferredFusedOpcode, VT, N020.getValueType())) { return FoldFAddFMAFPExtFMul(N0.getOperand(0), N0.getOperand(1), N020.getOperand(0), N020.getOperand(1), - N1); + N1, Flags); } } } @@ -9394,14 +10109,15 @@ SDValue DAGCombiner::visitFADDForFMACombine(SDNode *N) { // operation into two double-precision operations, which might not be // interesting for all targets, especially GPUs. auto FoldFAddFPExtFMAFMul = [&] ( - SDValue X, SDValue Y, SDValue U, SDValue V, SDValue Z) { + SDValue X, SDValue Y, SDValue U, SDValue V, SDValue Z, + SDNodeFlags Flags) { return DAG.getNode(PreferredFusedOpcode, SL, VT, DAG.getNode(ISD::FP_EXTEND, SL, VT, X), DAG.getNode(ISD::FP_EXTEND, SL, VT, Y), DAG.getNode(PreferredFusedOpcode, SL, VT, DAG.getNode(ISD::FP_EXTEND, SL, VT, U), DAG.getNode(ISD::FP_EXTEND, SL, VT, V), - Z)); + Z, Flags), Flags); }; if (N0.getOpcode() == ISD::FP_EXTEND) { SDValue N00 = N0.getOperand(0); @@ -9411,7 +10127,7 @@ SDValue DAGCombiner::visitFADDForFMACombine(SDNode *N) { TLI.isFPExtFoldable(PreferredFusedOpcode, VT, N00.getValueType())) { return FoldFAddFPExtFMAFMul(N00.getOperand(0), N00.getOperand(1), N002.getOperand(0), N002.getOperand(1), - N1); + N1, Flags); } } } @@ -9426,7 +10142,7 @@ SDValue DAGCombiner::visitFADDForFMACombine(SDNode *N) { TLI.isFPExtFoldable(PreferredFusedOpcode, VT, N120.getValueType())) { return FoldFAddFMAFPExtFMul(N1.getOperand(0), N1.getOperand(1), N120.getOperand(0), N120.getOperand(1), - N0); + N0, Flags); } } } @@ -9444,7 +10160,7 @@ SDValue DAGCombiner::visitFADDForFMACombine(SDNode *N) { TLI.isFPExtFoldable(PreferredFusedOpcode, VT, N10.getValueType())) { return FoldFAddFPExtFMAFMul(N10.getOperand(0), N10.getOperand(1), N102.getOperand(0), N102.getOperand(1), - N0); + N0, Flags); } } } @@ -9473,8 +10189,11 @@ SDValue DAGCombiner::visitFSUBForFMACombine(SDNode *N) { if (!HasFMAD && !HasFMA) return SDValue(); + const SDNodeFlags Flags = N->getFlags(); + bool CanFuse = Options.UnsafeFPMath || isContractable(N); bool AllowFusionGlobally = (Options.AllowFPOpFusion == FPOpFusion::Fast || - Options.UnsafeFPMath || HasFMAD); + CanFuse || HasFMAD); + // If the subtraction is not contractable, do not combine. 
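+ // E.g., (fsub (fmul a, b), c) becomes (fma a, b, (fneg c)) only when the + // node carries contract (or reassoc) flags, unsafe/fast FP-op fusion is + // enabled globally, or the target prefers FMAD; otherwise it is left alone.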
if (!AllowFusionGlobally && !isContractable(N)) return SDValue(); @@ -9499,16 +10218,17 @@ SDValue DAGCombiner::visitFSUBForFMACombine(SDNode *N) { if (isContractableFMUL(N0) && (Aggressive || N0->hasOneUse())) { return DAG.getNode(PreferredFusedOpcode, SL, VT, N0.getOperand(0), N0.getOperand(1), - DAG.getNode(ISD::FNEG, SL, VT, N1)); + DAG.getNode(ISD::FNEG, SL, VT, N1), Flags); } // fold (fsub x, (fmul y, z)) -> (fma (fneg y), z, x) // Note: Commutes FSUB operands. - if (isContractableFMUL(N1) && (Aggressive || N1->hasOneUse())) + if (isContractableFMUL(N1) && (Aggressive || N1->hasOneUse())) { return DAG.getNode(PreferredFusedOpcode, SL, VT, DAG.getNode(ISD::FNEG, SL, VT, N1.getOperand(0)), - N1.getOperand(1), N0); + N1.getOperand(1), N0, Flags); + } // fold (fsub (fneg (fmul, x, y)), z) -> (fma (fneg x), y, (fneg z)) if (N0.getOpcode() == ISD::FNEG && isContractableFMUL(N0.getOperand(0)) && @@ -9517,7 +10237,7 @@ SDValue DAGCombiner::visitFSUBForFMACombine(SDNode *N) { SDValue N01 = N0.getOperand(0).getOperand(1); return DAG.getNode(PreferredFusedOpcode, SL, VT, DAG.getNode(ISD::FNEG, SL, VT, N00), N01, - DAG.getNode(ISD::FNEG, SL, VT, N1)); + DAG.getNode(ISD::FNEG, SL, VT, N1), Flags); } // Look through FP_EXTEND nodes to do more combining. @@ -9533,7 +10253,7 @@ SDValue DAGCombiner::visitFSUBForFMACombine(SDNode *N) { N00.getOperand(0)), DAG.getNode(ISD::FP_EXTEND, SL, VT, N00.getOperand(1)), - DAG.getNode(ISD::FNEG, SL, VT, N1)); + DAG.getNode(ISD::FNEG, SL, VT, N1), Flags); } } @@ -9550,7 +10270,7 @@ SDValue DAGCombiner::visitFSUBForFMACombine(SDNode *N) { N10.getOperand(0))), DAG.getNode(ISD::FP_EXTEND, SL, VT, N10.getOperand(1)), - N0); + N0, Flags); } } @@ -9572,7 +10292,7 @@ SDValue DAGCombiner::visitFSUBForFMACombine(SDNode *N) { N000.getOperand(0)), DAG.getNode(ISD::FP_EXTEND, SL, VT, N000.getOperand(1)), - N1)); + N1, Flags)); } } } @@ -9595,7 +10315,7 @@ SDValue DAGCombiner::visitFSUBForFMACombine(SDNode *N) { N000.getOperand(0)), DAG.getNode(ISD::FP_EXTEND, SL, VT, N000.getOperand(1)), - N1)); + N1, Flags)); } } } @@ -9604,9 +10324,7 @@ SDValue DAGCombiner::visitFSUBForFMACombine(SDNode *N) { if (Aggressive) { // fold (fsub (fma x, y, (fmul u, v)), z) // -> (fma x, y (fma u, v, (fneg z))) - // FIXME: The UnsafeAlgebra flag should be propagated to FMA/FMAD, but FMF - // are currently only supported on binary nodes. - if (Options.UnsafeFPMath && N0.getOpcode() == PreferredFusedOpcode && + if (CanFuse && N0.getOpcode() == PreferredFusedOpcode && isContractableFMUL(N0.getOperand(2)) && N0->hasOneUse() && N0.getOperand(2)->hasOneUse()) { return DAG.getNode(PreferredFusedOpcode, SL, VT, @@ -9615,14 +10333,12 @@ SDValue DAGCombiner::visitFSUBForFMACombine(SDNode *N) { N0.getOperand(2).getOperand(0), N0.getOperand(2).getOperand(1), DAG.getNode(ISD::FNEG, SL, VT, - N1))); + N1), Flags), Flags); } // fold (fsub x, (fma y, z, (fmul u, v))) // -> (fma (fneg y), z, (fma (fneg u), v, x)) - // FIXME: The UnsafeAlgebra flag should be propagated to FMA/FMAD, but FMF - // are currently only supported on binary nodes. 
- if (Options.UnsafeFPMath && N1.getOpcode() == PreferredFusedOpcode && + if (CanFuse && N1.getOpcode() == PreferredFusedOpcode && isContractableFMUL(N1.getOperand(2))) { SDValue N20 = N1.getOperand(2).getOperand(0); SDValue N21 = N1.getOperand(2).getOperand(1); @@ -9632,8 +10348,7 @@ SDValue DAGCombiner::visitFSUBForFMACombine(SDNode *N) { N1.getOperand(1), DAG.getNode(PreferredFusedOpcode, SL, VT, DAG.getNode(ISD::FNEG, SL, VT, N20), - - N21, N0)); + N21, N0, Flags), Flags); } @@ -9653,7 +10368,7 @@ SDValue DAGCombiner::visitFSUBForFMACombine(SDNode *N) { DAG.getNode(ISD::FP_EXTEND, SL, VT, N020.getOperand(1)), DAG.getNode(ISD::FNEG, SL, VT, - N1))); + N1), Flags), Flags); } } } @@ -9681,7 +10396,7 @@ SDValue DAGCombiner::visitFSUBForFMACombine(SDNode *N) { DAG.getNode(ISD::FP_EXTEND, SL, VT, N002.getOperand(1)), DAG.getNode(ISD::FNEG, SL, VT, - N1))); + N1), Flags), Flags); } } } @@ -9704,7 +10419,7 @@ SDValue DAGCombiner::visitFSUBForFMACombine(SDNode *N) { VT, N1200)), DAG.getNode(ISD::FP_EXTEND, SL, VT, N1201), - N0)); + N0, Flags), Flags); } } @@ -9735,7 +10450,7 @@ SDValue DAGCombiner::visitFSUBForFMACombine(SDNode *N) { VT, N1020)), DAG.getNode(ISD::FP_EXTEND, SL, VT, N1021), - N0)); + N0, Flags), Flags); } } } @@ -9751,6 +10466,7 @@ SDValue DAGCombiner::visitFMULForFMADistributiveCombine(SDNode *N) { SDValue N1 = N->getOperand(1); EVT VT = N->getValueType(0); SDLoc SL(N); + const SDNodeFlags Flags = N->getFlags(); assert(N->getOpcode() == ISD::FMUL && "Expected FMUL Operation"); @@ -9782,52 +10498,54 @@ SDValue DAGCombiner::visitFMULForFMADistributiveCombine(SDNode *N) { // fold (fmul (fadd x, +1.0), y) -> (fma x, y, y) // fold (fmul (fadd x, -1.0), y) -> (fma x, y, (fneg y)) - auto FuseFADD = [&](SDValue X, SDValue Y) { + auto FuseFADD = [&](SDValue X, SDValue Y, const SDNodeFlags Flags) { if (X.getOpcode() == ISD::FADD && (Aggressive || X->hasOneUse())) { auto XC1 = isConstOrConstSplatFP(X.getOperand(1)); if (XC1 && XC1->isExactlyValue(+1.0)) - return DAG.getNode(PreferredFusedOpcode, SL, VT, X.getOperand(0), Y, Y); + return DAG.getNode(PreferredFusedOpcode, SL, VT, X.getOperand(0), Y, + Y, Flags); if (XC1 && XC1->isExactlyValue(-1.0)) return DAG.getNode(PreferredFusedOpcode, SL, VT, X.getOperand(0), Y, - DAG.getNode(ISD::FNEG, SL, VT, Y)); + DAG.getNode(ISD::FNEG, SL, VT, Y), Flags); } return SDValue(); }; - if (SDValue FMA = FuseFADD(N0, N1)) + if (SDValue FMA = FuseFADD(N0, N1, Flags)) return FMA; - if (SDValue FMA = FuseFADD(N1, N0)) + if (SDValue FMA = FuseFADD(N1, N0, Flags)) return FMA; // fold (fmul (fsub +1.0, x), y) -> (fma (fneg x), y, y) // fold (fmul (fsub -1.0, x), y) -> (fma (fneg x), y, (fneg y)) // fold (fmul (fsub x, +1.0), y) -> (fma x, y, (fneg y)) // fold (fmul (fsub x, -1.0), y) -> (fma x, y, y) - auto FuseFSUB = [&](SDValue X, SDValue Y) { + auto FuseFSUB = [&](SDValue X, SDValue Y, const SDNodeFlags Flags) { if (X.getOpcode() == ISD::FSUB && (Aggressive || X->hasOneUse())) { auto XC0 = isConstOrConstSplatFP(X.getOperand(0)); if (XC0 && XC0->isExactlyValue(+1.0)) return DAG.getNode(PreferredFusedOpcode, SL, VT, DAG.getNode(ISD::FNEG, SL, VT, X.getOperand(1)), Y, - Y); + Y, Flags); if (XC0 && XC0->isExactlyValue(-1.0)) return DAG.getNode(PreferredFusedOpcode, SL, VT, DAG.getNode(ISD::FNEG, SL, VT, X.getOperand(1)), Y, - DAG.getNode(ISD::FNEG, SL, VT, Y)); + DAG.getNode(ISD::FNEG, SL, VT, Y), Flags); auto XC1 = isConstOrConstSplatFP(X.getOperand(1)); if (XC1 && XC1->isExactlyValue(+1.0)) return DAG.getNode(PreferredFusedOpcode, SL, VT, X.getOperand(0), Y, - 
DAG.getNode(ISD::FNEG, SL, VT, Y)); + DAG.getNode(ISD::FNEG, SL, VT, Y), Flags); if (XC1 && XC1->isExactlyValue(-1.0)) - return DAG.getNode(PreferredFusedOpcode, SL, VT, X.getOperand(0), Y, Y); + return DAG.getNode(PreferredFusedOpcode, SL, VT, X.getOperand(0), Y, + Y, Flags); } return SDValue(); }; - if (SDValue FMA = FuseFSUB(N0, N1)) + if (SDValue FMA = FuseFSUB(N0, N1, Flags)) return FMA; - if (SDValue FMA = FuseFSUB(N1, N0)) + if (SDValue FMA = FuseFSUB(N1, N0, Flags)) return FMA; return SDValue(); @@ -9889,35 +10607,42 @@ SDValue DAGCombiner::visitFADD(SDNode *N) { return DAG.getNode(ISD::FSUB, DL, VT, N1IsFMul ? N0 : N1, Add, Flags); } - // FIXME: Auto-upgrade the target/function-level option. - if (Options.NoSignedZerosFPMath || N->getFlags().hasNoSignedZeros()) { - // fold (fadd A, 0) -> A - if (ConstantFPSDNode *N1C = isConstOrConstSplatFP(N1)) - if (N1C->isZero()) - return N0; + ConstantFPSDNode *N1C = isConstOrConstSplatFP(N1); + if (N1C && N1C->isZero()) { + if (N1C->isNegative() || Options.UnsafeFPMath || + Flags.hasNoSignedZeros()) { + // fold (fadd A, 0) -> A + return N0; + } } - // If 'unsafe math' is enabled, fold lots of things. - if (Options.UnsafeFPMath) { - // No FP constant should be created after legalization as Instruction - // Selection pass has a hard time dealing with FP constants. - bool AllowNewConst = (Level < AfterLegalizeDAG); - - // fold (fadd (fadd x, c1), c2) -> (fadd x, (fadd c1, c2)) - if (N1CFP && N0.getOpcode() == ISD::FADD && N0.getNode()->hasOneUse() && - isConstantFPBuildVectorOrConstantFP(N0.getOperand(1))) - return DAG.getNode(ISD::FADD, DL, VT, N0.getOperand(0), - DAG.getNode(ISD::FADD, DL, VT, N0.getOperand(1), N1, - Flags), - Flags); + // No FP constant should be created after legalization as the Instruction + // Selection pass has a hard time dealing with FP constants. + bool AllowNewConst = (Level < AfterLegalizeDAG); + // If 'unsafe math' or nnan is enabled, fold lots of things. + if ((Options.UnsafeFPMath || Flags.hasNoNaNs()) && AllowNewConst) { // If allowed, fold (fadd (fneg x), x) -> 0.0 - if (AllowNewConst && N0.getOpcode() == ISD::FNEG && N0.getOperand(0) == N1) + if (N0.getOpcode() == ISD::FNEG && N0.getOperand(0) == N1) return DAG.getConstantFP(0.0, DL, VT); // If allowed, fold (fadd x, (fneg x)) -> 0.0 - if (AllowNewConst && N1.getOpcode() == ISD::FNEG && N1.getOperand(0) == N0) + if (N1.getOpcode() == ISD::FNEG && N1.getOperand(0) == N0) return DAG.getConstantFP(0.0, DL, VT); + } + + // If 'unsafe math' is enabled, or both reassoc and nsz are set, fold lots + // of things. + // TODO: break out the portions of the transformations below for which + // Unsafe alone suffices and which do not require both nsz and reassoc + if ((Options.UnsafeFPMath || + (Flags.hasAllowReassociation() && Flags.hasNoSignedZeros())) && + AllowNewConst) { + // fadd (fadd x, c1), c2 -> fadd x, c1 + c2 + if (N1CFP && N0.getOpcode() == ISD::FADD && + isConstantFPBuildVectorOrConstantFP(N0.getOperand(1))) { + SDValue NewC = DAG.getNode(ISD::FADD, DL, VT, N0.getOperand(1), N1, Flags); + return DAG.getNode(ISD::FADD, DL, VT, N0.getOperand(0), NewC, Flags); + } // We can fold chains of FADD's of the same value into multiplications.
// This transform is not safe in general because we are reducing the number @@ -9965,7 +10690,7 @@ SDValue DAGCombiner::visitFADD(SDNode *N) { } } - if (N0.getOpcode() == ISD::FADD && AllowNewConst) { + if (N0.getOpcode() == ISD::FADD) { bool CFP00 = isConstantFPBuildVectorOrConstantFP(N0.getOperand(0)); // (fadd (fadd x, x), x) -> (fmul x, 3.0) if (!CFP00 && N0.getOperand(0) == N0.getOperand(1) && @@ -9975,7 +10700,7 @@ SDValue DAGCombiner::visitFADD(SDNode *N) { } } - if (N1.getOpcode() == ISD::FADD && AllowNewConst) { + if (N1.getOpcode() == ISD::FADD) { bool CFP10 = isConstantFPBuildVectorOrConstantFP(N1.getOperand(0)); // (fadd x, (fadd x, x)) -> (fmul x, 3.0) if (!CFP10 && N1.getOperand(0) == N1.getOperand(1) && @@ -9986,8 +10711,7 @@ SDValue DAGCombiner::visitFADD(SDNode *N) { } // (fadd (fadd x, x), (fadd x, x)) -> (fmul x, 4.0) - if (AllowNewConst && - N0.getOpcode() == ISD::FADD && N1.getOpcode() == ISD::FADD && + if (N0.getOpcode() == ISD::FADD && N1.getOpcode() == ISD::FADD && N0.getOperand(0) == N0.getOperand(1) && N1.getOperand(0) == N1.getOperand(1) && N0.getOperand(0) == N1.getOperand(0)) { @@ -10027,15 +10751,23 @@ SDValue DAGCombiner::visitFSUB(SDNode *N) { if (SDValue NewSel = foldBinOpIntoSelect(N)) return NewSel; - // fold (fsub A, (fneg B)) -> (fadd A, B) - if (isNegatibleForFree(N1, LegalOperations, TLI, &Options)) - return DAG.getNode(ISD::FADD, DL, VT, N0, - GetNegatedExpression(N1, DAG, LegalOperations), Flags); + // (fsub A, 0) -> A + if (N1CFP && N1CFP->isZero()) { + if (!N1CFP->isNegative() || Options.UnsafeFPMath || + Flags.hasNoSignedZeros()) { + return N0; + } + } + + if (N0 == N1) { + // (fsub x, x) -> 0.0 + if (Options.UnsafeFPMath || Flags.hasNoNaNs()) + return DAG.getConstantFP(0.0f, DL, VT); + } - // FIXME: Auto-upgrade the target/function-level option. - if (Options.NoSignedZerosFPMath || N->getFlags().hasNoSignedZeros()) { - // (fsub 0, B) -> -B - if (N0CFP && N0CFP->isZero()) { + // (fsub 0, B) -> -B + if (N0CFP && N0CFP->isZero()) { + if (Options.NoSignedZerosFPMath || Flags.hasNoSignedZeros()) { if (isNegatibleForFree(N1, LegalOperations, TLI, &Options)) return GetNegatedExpression(N1, DAG, LegalOperations); if (!LegalOperations || TLI.isOperationLegal(ISD::FNEG, VT)) @@ -10043,16 +10775,13 @@ SDValue DAGCombiner::visitFSUB(SDNode *N) { } } + // fold (fsub A, (fneg B)) -> (fadd A, B) + if (isNegatibleForFree(N1, LegalOperations, TLI, &Options)) + return DAG.getNode(ISD::FADD, DL, VT, N0, + GetNegatedExpression(N1, DAG, LegalOperations), Flags); + // If 'unsafe math' is enabled, fold lots of things. if (Options.UnsafeFPMath) { - // (fsub A, 0) -> A - if (N1CFP && N1CFP->isZero()) - return N0; - - // (fsub x, x) -> 0.0 - if (N0 == N1) - return DAG.getConstantFP(0.0f, DL, VT); - // (fsub x, (fadd x, y)) -> (fneg y) // (fsub x, (fadd y, x)) -> (fneg y) if (N1.getOpcode() == ISD::FADD) { @@ -10109,12 +10838,15 @@ SDValue DAGCombiner::visitFMUL(SDNode *N) { if (SDValue NewSel = foldBinOpIntoSelect(N)) return NewSel; - if (Options.UnsafeFPMath) { + if (Options.UnsafeFPMath || + (Flags.hasNoNaNs() && Flags.hasNoSignedZeros())) { // fold (fmul A, 0) -> 0 if (N1CFP && N1CFP->isZero()) return N1; + } - // fold (fmul (fmul x, c1), c2) -> (fmul x, (fmul c1, c2)) + if (Options.UnsafeFPMath || Flags.hasAllowReassociation()) { + // fmul (fmul X, C1), C2 -> fmul X, C1 * C2 if (N0.getOpcode() == ISD::FMUL) { // Fold scalars or any vector constants (not just splats). 
// This fold is done in general by InstCombine, but extra fmul insts @@ -10138,13 +10870,10 @@ SDValue DAGCombiner::visitFMUL(SDNode *N) { } } - // fold (fmul (fadd x, x), c) -> (fmul x, (fmul 2.0, c)) - // Undo the fmul 2.0, x -> fadd x, x transformation, since if it occurs - // during an early run of DAGCombiner can prevent folding with fmuls - // inserted during lowering. - if (N0.getOpcode() == ISD::FADD && - (N0.getOperand(0) == N0.getOperand(1)) && - N0.hasOneUse()) { + // Match a special-case: we convert X * 2.0 into fadd. + // fmul (fadd X, X), C -> fmul X, 2.0 * C + if (N0.getOpcode() == ISD::FADD && N0.hasOneUse() && + N0.getOperand(0) == N0.getOperand(1)) { const SDValue Two = DAG.getConstantFP(2.0, DL, VT); SDValue MulConsts = DAG.getNode(ISD::FMUL, DL, VT, Two, N1, Flags); return DAG.getNode(ISD::FMUL, DL, VT, N0.getOperand(0), MulConsts, Flags); @@ -10238,6 +10967,10 @@ SDValue DAGCombiner::visitFMA(SDNode *N) { SDLoc DL(N); const TargetOptions &Options = DAG.getTarget().Options; + // FMA nodes have flags that propagate to the created nodes. + const SDNodeFlags Flags = N->getFlags(); + bool UnsafeFPMath = Options.UnsafeFPMath || isContractable(N); + // Constant fold FMA. if (isa<ConstantFPSDNode>(N0) && isa<ConstantFPSDNode>(N1) && @@ -10245,7 +10978,7 @@ SDValue DAGCombiner::visitFMA(SDNode *N) { return DAG.getNode(ISD::FMA, DL, VT, N0, N1, N2); } - if (Options.UnsafeFPMath) { + if (UnsafeFPMath) { if (N0CFP && N0CFP->isZero()) return N2; if (N1CFP && N1CFP->isZero()) @@ -10262,12 +10995,7 @@ SDValue DAGCombiner::visitFMA(SDNode *N) { !isConstantFPBuildVectorOrConstantFP(N1)) return DAG.getNode(ISD::FMA, SDLoc(N), VT, N1, N0, N2); - // TODO: FMA nodes should have flags that propagate to the created nodes. - // For now, create a Flags object for use with all unsafe math transforms. - SDNodeFlags Flags; - Flags.setUnsafeAlgebra(true); - - if (Options.UnsafeFPMath) { + if (UnsafeFPMath) { // (fma x, c1, (fmul x, c2)) -> (fmul x, c1+c2) if (N2.getOpcode() == ISD::FMUL && N0 == N2.getOperand(0) && isConstantFPBuildVectorOrConstantFP(N1) && @@ -10313,7 +11041,7 @@ SDValue DAGCombiner::visitFMA(SDNode *N) { } } - if (Options.UnsafeFPMath) { + if (UnsafeFPMath) { // (fma x, c, x) -> (fmul x, (c+1)) if (N1CFP && N0 == N2) { return DAG.getNode(ISD::FMUL, DL, VT, N0, @@ -10420,7 +11148,7 @@ SDValue DAGCombiner::visitFDIV(SDNode *N) { if (SDValue NewSel = foldBinOpIntoSelect(N)) return NewSel; - if (Options.UnsafeFPMath) { + if (Options.UnsafeFPMath || Flags.hasAllowReciprocal()) { // fold (fdiv X, c2) -> fmul X, 1/c2 if losing precision is acceptable. if (N1CFP) { // Compute the reciprocal 1.0 / c2. @@ -10529,17 +11257,16 @@ SDValue DAGCombiner::visitFREM(SDNode *N) { } SDValue DAGCombiner::visitFSQRT(SDNode *N) { - if (!DAG.getTarget().Options.UnsafeFPMath) + SDNodeFlags Flags = N->getFlags(); + if (!DAG.getTarget().Options.UnsafeFPMath && + !Flags.hasApproximateFuncs()) return SDValue(); SDValue N0 = N->getOperand(0); if (TLI.isFsqrtCheap(N0, DAG)) return SDValue(); - // TODO: FSQRT nodes should have flags that propagate to the created nodes. - // For now, create a Flags object for use with all unsafe math transforms. - SDNodeFlags Flags; - Flags.setUnsafeAlgebra(true); + // FSQRT nodes have flags that propagate to the created nodes. 
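+ // E.g., with 'afn' set, (fsqrt x) may be expanded into a target reciprocal + // square-root estimate refined by Newton-Raphson steps, with these flags + // attached to every node the expansion creates (a sketch of what + // buildSqrtEstimate does; the exact sequence is target-dependent).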
return buildSqrtEstimate(N0, Flags); } @@ -10607,6 +11334,41 @@ SDValue DAGCombiner::visitFCOPYSIGN(SDNode *N) { return SDValue(); } +static SDValue foldFPToIntToFP(SDNode *N, SelectionDAG &DAG, + const TargetLowering &TLI) { + // This optimization is guarded by a function attribute because it may produce + // unexpected results. I.e., programs may be relying on the platform-specific + // undefined behavior when the float-to-int conversion overflows. + const Function &F = DAG.getMachineFunction().getFunction(); + Attribute StrictOverflow = F.getFnAttribute("strict-float-cast-overflow"); + if (StrictOverflow.getValueAsString().equals("false")) + return SDValue(); + + // We only do this if the target has legal ftrunc. Otherwise, we'd likely be + // replacing casts with a libcall. We also must be allowed to ignore -0.0 + // because FTRUNC will return -0.0 for (-1.0, -0.0), but using integer + // conversions would return +0.0. + // FIXME: We should be able to use node-level FMF here. + // TODO: If strict math, should we use FABS (+ range check for signed cast)? + EVT VT = N->getValueType(0); + if (!TLI.isOperationLegal(ISD::FTRUNC, VT) || + !DAG.getTarget().Options.NoSignedZerosFPMath) + return SDValue(); + + // fptosi/fptoui round towards zero, so converting from FP to integer and + // back is the same as an 'ftrunc': [us]itofp (fpto[us]i X) --> ftrunc X + SDValue N0 = N->getOperand(0); + if (N->getOpcode() == ISD::SINT_TO_FP && N0.getOpcode() == ISD::FP_TO_SINT && + N0.getOperand(0).getValueType() == VT) + return DAG.getNode(ISD::FTRUNC, SDLoc(N), VT, N0.getOperand(0)); + + if (N->getOpcode() == ISD::UINT_TO_FP && N0.getOpcode() == ISD::FP_TO_UINT && + N0.getOperand(0).getValueType() == VT) + return DAG.getNode(ISD::FTRUNC, SDLoc(N), VT, N0.getOperand(0)); + + return SDValue(); +} + SDValue DAGCombiner::visitSINT_TO_FP(SDNode *N) { SDValue N0 = N->getOperand(0); EVT VT = N->getValueType(0); @@ -10658,6 +11420,9 @@ SDValue DAGCombiner::visitSINT_TO_FP(SDNode *N) { } } + if (SDValue FTrunc = foldFPToIntToFP(N, DAG, TLI)) + return FTrunc; + return SDValue(); } @@ -10697,6 +11462,9 @@ SDValue DAGCombiner::visitUINT_TO_FP(SDNode *N) { } } + if (SDValue FTrunc = foldFPToIntToFP(N, DAG, TLI)) + return FTrunc; + return SDValue(); } @@ -11103,16 +11871,22 @@ SDValue DAGCombiner::visitBRCOND(SDNode *N) { N1.getOperand(0), N1.getOperand(1), N2); } - if ((N1.hasOneUse() && N1.getOpcode() == ISD::SRL) || - ((N1.getOpcode() == ISD::TRUNCATE && N1.hasOneUse()) && - (N1.getOperand(0).hasOneUse() && - N1.getOperand(0).getOpcode() == ISD::SRL))) { - SDNode *Trunc = nullptr; - if (N1.getOpcode() == ISD::TRUNCATE) { - // Look pass the truncate. - Trunc = N1.getNode(); - N1 = N1.getOperand(0); - } + if (N1.hasOneUse()) { + if (SDValue NewN1 = rebuildSetCC(N1)) + return DAG.getNode(ISD::BRCOND, SDLoc(N), MVT::Other, Chain, NewN1, N2); + } + + return SDValue(); +} + +SDValue DAGCombiner::rebuildSetCC(SDValue N) { + if (N.getOpcode() == ISD::SRL || + (N.getOpcode() == ISD::TRUNCATE && + (N.getOperand(0).hasOneUse() && + N.getOperand(0).getOpcode() == ISD::SRL))) { + // Look past the truncate. + if (N.getOpcode() == ISD::TRUNCATE) + N = N.getOperand(0); // Match this pattern so that we can generate simpler code: // // %a = ... // %b = and i32 %a, 2 // %c = srl i32 %b, 1 // brcond i32 %c ... // // This applies only when the AND constant value has one bit set and the // SRL constant is equal to the log2 of the AND constant. The back-end is // smart enough to convert the result into a TEST/JMP sequence.
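// In other words, (brcond (srl (and X, 2), 1)) is rebuilt as // (brcond (setcc (and X, 2), 0, ne)), letting the target emit a plain // bit-test-and-branch.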
- SDValue Op0 = N1.getOperand(0); - SDValue Op1 = N1.getOperand(1); + SDValue Op0 = N.getOperand(0); + SDValue Op1 = N.getOperand(1); - if (Op0.getOpcode() == ISD::AND && - Op1.getOpcode() == ISD::Constant) { + if (Op0.getOpcode() == ISD::AND && Op1.getOpcode() == ISD::Constant) { SDValue AndOp1 = Op0.getOperand(1); if (AndOp1.getOpcode() == ISD::Constant) { const APInt &AndConst = cast<ConstantSDNode>(AndOp1)->getAPIntValue(); if (AndConst.isPowerOf2() && - cast<ConstantSDNode>(Op1)->getAPIntValue()==AndConst.logBase2()) { + cast<ConstantSDNode>(Op1)->getAPIntValue() == AndConst.logBase2()) { SDLoc DL(N); - SDValue SetCC = - DAG.getSetCC(DL, - getSetCCResultType(Op0.getValueType()), - Op0, DAG.getConstant(0, DL, Op0.getValueType()), - ISD::SETNE); - - SDValue NewBRCond = DAG.getNode(ISD::BRCOND, DL, - MVT::Other, Chain, SetCC, N2); - // Don't add the new BRCond into the worklist or else SimplifySelectCC - // will convert it back to (X & C1) >> C2. - CombineTo(N, NewBRCond, false); - // Truncate is dead. - if (Trunc) - deleteAndRecombine(Trunc); - // Replace the uses of SRL with SETCC - WorklistRemover DeadNodes(*this); - DAG.ReplaceAllUsesOfValueWith(N1, SetCC); - deleteAndRecombine(N1.getNode()); - return SDValue(N, 0); // Return N so it doesn't get rechecked! + return DAG.getSetCC(DL, getSetCCResultType(Op0.getValueType()), + Op0, DAG.getConstant(0, DL, Op0.getValueType()), + ISD::SETNE); } } } - - if (Trunc) - // Restore N1 if the above transformation doesn't match. - N1 = N->getOperand(1); } // Transform br(xor(x, y)) -> br(x != y) // Transform br(xor(xor(x,y), 1)) -> br (x == y) - if (N1.hasOneUse() && N1.getOpcode() == ISD::XOR) { - SDNode *TheXor = N1.getNode(); + if (N.getOpcode() == ISD::XOR) { + // Because we may call this on a speculatively constructed + // SimplifiedSetCC Node, we need to simplify this node first. + // Ideally this should be folded into SimplifySetCC and not + // here. For now, grab a handle to N so we don't lose it from + // replacements internal to the visit. + HandleSDNode XORHandle(N); + while (N.getOpcode() == ISD::XOR) { + SDValue Tmp = visitXOR(N.getNode()); + // No simplification done. + if (!Tmp.getNode()) + break; + // Returning N is a form of in-visit replacement that may invalidate + // N. Grab the value from the handle. + if (Tmp.getNode() == N.getNode()) + N = XORHandle.getValue(); + else // Node simplified. Try simplifying again. + N = Tmp; + } + + if (N.getOpcode() != ISD::XOR) + return N; + + SDNode *TheXor = N.getNode(); + SDValue Op0 = TheXor->getOperand(0); SDValue Op1 = TheXor->getOperand(1); - if (Op0.getOpcode() == Op1.getOpcode()) { - // Avoid missing important xor optimizations. - if (SDValue Tmp = visitXOR(TheXor)) { - if (Tmp.getNode() != TheXor) { - DEBUG(dbgs() << "\nReplacing.8 "; - TheXor->dump(&DAG); - dbgs() << "\nWith: "; - Tmp.getNode()->dump(&DAG); - dbgs() << '\n'); - WorklistRemover DeadNodes(*this); - DAG.ReplaceAllUsesOfValueWith(N1, Tmp); - deleteAndRecombine(TheXor); - return DAG.getNode(ISD::BRCOND, SDLoc(N), - MVT::Other, Chain, Tmp, N2); - } - - // visitXOR has changed XOR's operands or replaced the XOR completely, - // bail out.
- return SDValue(N, 0); - } - } if (Op0.getOpcode() != ISD::SETCC && Op1.getOpcode() != ISD::SETCC) { bool Equal = false; @@ -11208,19 +11963,12 @@ SDValue DAGCombiner::visitBRCOND(SDNode *N) { Equal = true; } - EVT SetCCVT = N1.getValueType(); + EVT SetCCVT = N.getValueType(); if (LegalTypes) SetCCVT = getSetCCResultType(SetCCVT); - SDValue SetCC = DAG.getSetCC(SDLoc(TheXor), - SetCCVT, - Op0, Op1, - Equal ? ISD::SETEQ : ISD::SETNE); // Replace the uses of XOR with SETCC - WorklistRemover DeadNodes(*this); - DAG.ReplaceAllUsesOfValueWith(N1, SetCC); - deleteAndRecombine(N1.getNode()); - return DAG.getNode(ISD::BRCOND, SDLoc(N), - MVT::Other, Chain, SetCC, N2); + return DAG.getSetCC(SDLoc(TheXor), SetCCVT, Op0, Op1, + Equal ? ISD::SETEQ : ISD::SETNE); } } @@ -11452,11 +12200,8 @@ bool DAGCombiner::CombineToPreIndexedLoadStore(SDNode *N) { BasePtr, Offset, AM); ++PreIndexedNodes; ++NodesCombined; - DEBUG(dbgs() << "\nReplacing.4 "; - N->dump(&DAG); - dbgs() << "\nWith: "; - Result.getNode()->dump(&DAG); - dbgs() << '\n'); + LLVM_DEBUG(dbgs() << "\nReplacing.4 "; N->dump(&DAG); dbgs() << "\nWith: "; + Result.getNode()->dump(&DAG); dbgs() << '\n'); WorklistRemover DeadNodes(*this); if (isLoad) { DAG.ReplaceAllUsesOfValueWith(SDValue(N, 0), Result.getValue(0)); @@ -11621,11 +12366,9 @@ bool DAGCombiner::CombineToPostIndexedLoadStore(SDNode *N) { BasePtr, Offset, AM); ++PostIndexedNodes; ++NodesCombined; - DEBUG(dbgs() << "\nReplacing.5 "; - N->dump(&DAG); - dbgs() << "\nWith: "; - Result.getNode()->dump(&DAG); - dbgs() << '\n'); + LLVM_DEBUG(dbgs() << "\nReplacing.5 "; N->dump(&DAG); + dbgs() << "\nWith: "; Result.getNode()->dump(&DAG); + dbgs() << '\n'); WorklistRemover DeadNodes(*this); if (isLoad) { DAG.ReplaceAllUsesOfValueWith(SDValue(N, 0), Result.getValue(0)); @@ -11649,7 +12392,7 @@ bool DAGCombiner::CombineToPostIndexedLoadStore(SDNode *N) { return false; } -/// \brief Return the base-pointer arithmetic from an indexed \p LD. +/// Return the base-pointer arithmetic from an indexed \p LD. SDValue DAGCombiner::SplitIndexingFromLoad(LoadSDNode *LD) { ISD::MemIndexedMode AM = LD->getAddressingMode(); assert(AM != ISD::UNINDEXED); @@ -11691,11 +12434,9 @@ SDValue DAGCombiner::visitLOAD(SDNode *N) { // v3 = add v2, c // Now we replace use of chain2 with chain1. This makes the second load // isomorphic to the one we are deleting, and thus makes this load live. - DEBUG(dbgs() << "\nReplacing.6 "; - N->dump(&DAG); - dbgs() << "\nWith chain: "; - Chain.getNode()->dump(&DAG); - dbgs() << "\n"); + LLVM_DEBUG(dbgs() << "\nReplacing.6 "; N->dump(&DAG); + dbgs() << "\nWith chain: "; Chain.getNode()->dump(&DAG); + dbgs() << "\n"); WorklistRemover DeadNodes(*this); DAG.ReplaceAllUsesOfValueWith(SDValue(N, 1), Chain); AddUsersToWorklist(Chain.getNode()); @@ -11726,11 +12467,9 @@ SDValue DAGCombiner::visitLOAD(SDNode *N) { AddUsersToWorklist(N); } else Index = DAG.getUNDEF(N->getValueType(1)); - DEBUG(dbgs() << "\nReplacing.7 "; - N->dump(&DAG); - dbgs() << "\nWith: "; - Undef.getNode()->dump(&DAG); - dbgs() << " and 2 other values\n"); + LLVM_DEBUG(dbgs() << "\nReplacing.7 "; N->dump(&DAG); + dbgs() << "\nWith: "; Undef.getNode()->dump(&DAG); + dbgs() << " and 2 other values\n"); WorklistRemover DeadNodes(*this); DAG.ReplaceAllUsesOfValueWith(SDValue(N, 0), Undef); DAG.ReplaceAllUsesOfValueWith(SDValue(N, 1), Index); @@ -11758,13 +12497,14 @@ SDValue DAGCombiner::visitLOAD(SDNode *N) { // Try to infer better alignment information than the load already has. 
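// E.g., a load tagged align 4 whose pointer provably points to the start of // a 16-byte-aligned frame object can be retagged align 16; an illustrative // case of what InferPtrAlignment can prove.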
if (OptLevel != CodeGenOpt::None && LD->isUnindexed()) { if (unsigned Align = DAG.InferPtrAlignment(Ptr)) { - if (Align > LD->getMemOperand()->getBaseAlignment()) { + if (Align > LD->getAlignment() && LD->getSrcValueOffset() % Align == 0) { SDValue NewLoad = DAG.getExtLoad( LD->getExtensionType(), SDLoc(N), LD->getValueType(0), Chain, Ptr, LD->getPointerInfo(), LD->getMemoryVT(), Align, LD->getMemOperand()->getFlags(), LD->getAAInfo()); - if (NewLoad.getNode() != N) - return CombineTo(N, NewLoad, SDValue(NewLoad.getNode(), 1), true); + // NewLoad will always be N as we are only refining the alignment + assert(NewLoad.getNode() == N); + (void)NewLoad; } } } @@ -11811,7 +12551,7 @@ SDValue DAGCombiner::visitLOAD(SDNode *N) { namespace { -/// \brief Helper structure used to slice a load in smaller loads. +/// Helper structure used to slice a load in smaller loads. /// Basically a slice is obtained from the following sequence: /// Origin = load Ty1, Base /// Shift = srl Ty1 Origin, CstTy Amount @@ -11824,7 +12564,7 @@ namespace { /// SliceTy is deduced from the number of bits that are actually used to /// build Inst. struct LoadedSlice { - /// \brief Helper structure used to compute the cost of a slice. + /// Helper structure used to compute the cost of a slice. struct Cost { /// Are we optimizing for code size. bool ForCodeSize; @@ -11838,7 +12578,7 @@ struct LoadedSlice { Cost(bool ForCodeSize = false) : ForCodeSize(ForCodeSize) {} - /// \brief Get the cost of one isolated slice. + /// Get the cost of one isolated slice. Cost(const LoadedSlice &LS, bool ForCodeSize = false) : ForCodeSize(ForCodeSize), Loads(1) { EVT TruncType = LS.Inst->getValueType(0); @@ -11848,7 +12588,7 @@ struct LoadedSlice { ZExts = 1; } - /// \brief Account for slicing gain in the current cost. + /// Account for slicing gain in the current cost. /// Slicing provide a few gains like removing a shift or a /// truncate. This method allows to grow the cost of the original /// load with the gain from this slice. @@ -11921,7 +12661,7 @@ struct LoadedSlice { unsigned Shift = 0, SelectionDAG *DAG = nullptr) : Inst(Inst), Origin(Origin), Shift(Shift), DAG(DAG) {} - /// \brief Get the bits used in a chunk of bits \p BitWidth large. + /// Get the bits used in a chunk of bits \p BitWidth large. /// \return Result is \p BitWidth and has used bits set to 1 and /// not used bits set to 0. APInt getUsedBits() const { @@ -11941,14 +12681,14 @@ struct LoadedSlice { return UsedBits; } - /// \brief Get the size of the slice to be loaded in bytes. + /// Get the size of the slice to be loaded in bytes. unsigned getLoadedSize() const { unsigned SliceSize = getUsedBits().countPopulation(); assert(!(SliceSize & 0x7) && "Size is not a multiple of a byte."); return SliceSize / 8; } - /// \brief Get the type that will be loaded for this slice. + /// Get the type that will be loaded for this slice. /// Note: This may not be the final type for the slice. EVT getLoadedType() const { assert(DAG && "Missing context"); @@ -11956,7 +12696,7 @@ struct LoadedSlice { return EVT::getIntegerVT(Ctxt, getLoadedSize() * 8); } - /// \brief Get the alignment of the load used for this slice. + /// Get the alignment of the load used for this slice. unsigned getAlignment() const { unsigned Alignment = Origin->getAlignment(); unsigned Offset = getOffsetFromBase(); @@ -11965,7 +12705,7 @@ struct LoadedSlice { return Alignment; } - /// \brief Check if this slice can be rewritten with legal operations. + /// Check if this slice can be rewritten with legal operations. 
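+ /// For instance, a slice is rejected when the narrowed load or the + /// truncate it feeds would not be legal for the target (an informal + /// summary of the checks below).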
bool isLegal() const { // An invalid slice is not legal. if (!Origin || !Inst || !DAG) @@ -12009,7 +12749,7 @@ struct LoadedSlice { return true; } - /// \brief Get the offset in bytes of this slice in the original chunk of + /// Get the offset in bytes of this slice in the original chunk of /// bits. /// \pre DAG != nullptr. uint64_t getOffsetFromBase() const { @@ -12030,7 +12770,7 @@ struct LoadedSlice { return Offset; } - /// \brief Generate the sequence of instructions to load the slice + /// Generate the sequence of instructions to load the slice /// represented by this object and redirect the uses of this slice to /// this new sequence of instructions. /// \pre this->Inst && this->Origin are valid Instructions and this @@ -12068,7 +12808,7 @@ struct LoadedSlice { return LastInst; } - /// \brief Check if this slice can be merged with an expensive cross register + /// Check if this slice can be merged with an expensive cross register /// bank copy. E.g., /// i = load i32 /// f = bitcast i32 i to float @@ -12117,7 +12857,7 @@ struct LoadedSlice { } // end anonymous namespace -/// \brief Check that all bits set in \p UsedBits form a dense region, i.e., +/// Check that all bits set in \p UsedBits form a dense region, i.e., /// \p UsedBits looks like 0..0 1..1 0..0. static bool areUsedBitsDense(const APInt &UsedBits) { // If all the bits are one, this is dense! @@ -12133,7 +12873,7 @@ static bool areUsedBitsDense(const APInt &UsedBits) { return NarrowedUsedBits.isAllOnesValue(); } -/// \brief Check whether or not \p First and \p Second are next to each other +/// Check whether or not \p First and \p Second are next to each other /// in memory. This means that there is no hole between the bits loaded /// by \p First and the bits loaded by \p Second. static bool areSlicesNextToEachOther(const LoadedSlice &First, @@ -12147,7 +12887,7 @@ static bool areSlicesNextToEachOther(const LoadedSlice &First, return areUsedBitsDense(UsedBits); } -/// \brief Adjust the \p GlobalLSCost according to the target +/// Adjust the \p GlobalLSCost according to the target /// pairing capabilities and the layout of the slices. /// \pre \p GlobalLSCost should account for at least as many loads as /// there are in the slices in \p LoadedSlices. @@ -12160,8 +12900,8 @@ static void adjustCostForPairing(SmallVectorImpl<LoadedSlice> &LoadedSlices, // Sort the slices so that elements that are likely to be next to each // other in memory are next to each other in the list. - std::sort(LoadedSlices.begin(), LoadedSlices.end(), - [](const LoadedSlice &LHS, const LoadedSlice &RHS) { + llvm::sort(LoadedSlices.begin(), LoadedSlices.end(), - [](const LoadedSlice &LHS, const LoadedSlice &RHS) { + [](const LoadedSlice &LHS, const LoadedSlice &RHS) { assert(LHS.Origin == RHS.Origin && "Different bases not implemented."); return LHS.getOffsetFromBase() < RHS.getOffsetFromBase(); }); @@ -12208,7 +12948,7 @@ static void adjustCostForPairing(SmallVectorImpl<LoadedSlice> &LoadedSlices, } } -/// \brief Check the profitability of all involved LoadedSlice. +/// Check the profitability of all involved LoadedSlice. /// Currently, it is considered profitable if there are exactly two /// involved slices (1) which are (2) next to each other in memory, and /// whose cost (\see LoadedSlice::Cost) is smaller than the original load (3).
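/// For instance, an i32 load used only as (trunc to i16) and /// (trunc (srl 16) to i16) yields two adjacent i16 slices, which is /// profitable when the target can pair the two narrow loads (illustrative /// case).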
@@ -12252,7 +12992,7 @@ static bool isSlicingProfitable(SmallVectorImpl<LoadedSlice> &LoadedSlices, return OrigCost > GlobalSlicingCost; } -/// \brief If the given load, \p LI, is used only by trunc or trunc(lshr) +/// If the given load, \p LI, is used only by trunc or trunc(lshr) /// operations, split it into the various pieces being extracted. /// /// This sort of thing is introduced by SROA. @@ -12371,22 +13111,6 @@ CheckForMaskedLoad(SDValue V, SDValue Ptr, SDValue Chain) { LoadSDNode *LD = cast<LoadSDNode>(V->getOperand(0)); if (LD->getBasePtr() != Ptr) return Result; // Not from same pointer. - // The store should be chained directly to the load or be an operand of a - // tokenfactor. - if (LD == Chain.getNode()) - ; // ok. - else if (Chain->getOpcode() != ISD::TokenFactor) - return Result; // Fail. - else { - bool isOk = false; - for (const SDValue &ChainOp : Chain->op_values()) - if (ChainOp.getNode() == LD) { - isOk = true; - break; - } - if (!isOk) return Result; - } - // This only handles simple types. if (V.getValueType() != MVT::i16 && V.getValueType() != MVT::i32 && @@ -12423,6 +13147,24 @@ CheckForMaskedLoad(SDValue V, SDValue Ptr, SDValue Chain) { // is aligned the same as the access width. if (NotMaskTZ && NotMaskTZ/8 % MaskedBytes) return Result; + // For narrowing to be valid, it must be the case that the load is the + // memory operation immediately preceding the store. + if (LD == Chain.getNode()) + ; // ok. + else if (Chain->getOpcode() == ISD::TokenFactor && + SDValue(LD, 1).hasOneUse()) { + // LD has only 1 chain use, so there are no indirect dependencies. + bool isOk = false; + for (const SDValue &ChainOp : Chain->op_values()) + if (ChainOp.getNode() == LD) { + isOk = true; + break; + } + if (!isOk) + return Result; + } else + return Result; // Fail.
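+ // Success: the store can be narrowed to MaskedBytes bytes starting + // NotMaskTZ/8 bytes into the value, per the mask analysis above.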
+ Result.first = MaskedBytes; Result.second = NotMaskTZ/8; return Result; @@ -12741,12 +13483,6 @@ bool DAGCombiner::isMulAddWithConstProfitable(SDNode *MulNode, return false; } -static SDValue peekThroughBitcast(SDValue V) { - while (V.getOpcode() == ISD::BITCAST) - V = V.getOperand(0); - return V; -} - SDValue DAGCombiner::getMergeStoreChains(SmallVectorImpl<MemOpLink> &StoreNodes, unsigned NumStores) { SmallVector<SDValue, 8> Chains; @@ -12871,6 +13607,7 @@ bool DAGCombiner::MergeStoresOfConstantsOrVecElts( StoreSDNode *St = cast<StoreSDNode>(StoreNodes[Idx].MemNode); SDValue Val = St->getValue(); + Val = peekThroughBitcast(Val); StoreInt <<= ElementSizeBits; if (ConstantSDNode *C = dyn_cast<ConstantSDNode>(Val)) { StoreInt |= C->getAPIntValue() @@ -12903,13 +13640,13 @@ bool DAGCombiner::MergeStoresOfConstantsOrVecElts( FirstInChain->getPointerInfo(), FirstInChain->getAlignment()); } else { // Must be realized as a trunc store - EVT LegalizedStoredValueTy = + EVT LegalizedStoredValTy = TLI.getTypeToTransformTo(*DAG.getContext(), StoredVal.getValueType()); - unsigned LegalizedStoreSize = LegalizedStoredValueTy.getSizeInBits(); + unsigned LegalizedStoreSize = LegalizedStoredValTy.getSizeInBits(); ConstantSDNode *C = cast<ConstantSDNode>(StoredVal); SDValue ExtendedStoreVal = DAG.getConstant(C->getAPIntValue().zextOrTrunc(LegalizedStoreSize), DL, - LegalizedStoredValueTy); + LegalizedStoredValTy); NewStore = DAG.getTruncStore( NewChain, DL, ExtendedStoreVal, FirstInChain->getBasePtr(), FirstInChain->getPointerInfo(), StoredVal.getValueType() /*TVT*/, @@ -12926,10 +13663,11 @@ bool DAGCombiner::MergeStoresOfConstantsOrVecElts( } void DAGCombiner::getStoreMergeCandidates( - StoreSDNode *St, SmallVectorImpl<MemOpLink> &StoreNodes) { + StoreSDNode *St, SmallVectorImpl<MemOpLink> &StoreNodes, + SDNode *&RootNode) { // This holds the base pointer, index, and the offset in bytes from the base // pointer. - BaseIndexOffset BasePtr = BaseIndexOffset::match(St->getBasePtr(), DAG); + BaseIndexOffset BasePtr = BaseIndexOffset::match(St, DAG); EVT MemVT = St->getMemoryVT(); SDValue Val = peekThroughBitcast(St->getValue()); @@ -12950,11 +13688,17 @@ void DAGCombiner::getStoreMergeCandidates( EVT LoadVT; if (IsLoadSrc) { auto *Ld = cast<LoadSDNode>(Val); - LBasePtr = BaseIndexOffset::match(Ld->getBasePtr(), DAG); + LBasePtr = BaseIndexOffset::match(Ld, DAG); LoadVT = Ld->getMemoryVT(); // Load and store should be the same type. if (MemVT != LoadVT) return; + // Loads must only have one use. + if (!Ld->hasNUsesOfValue(1, 0)) + return; + // The memory operands must not be volatile. + if (Ld->isVolatile() || Ld->isIndexed()) + return; } auto CandidateMatch = [&](StoreSDNode *Other, BaseIndexOffset &Ptr, int64_t &Offset) -> bool { @@ -12969,9 +13713,15 @@ void DAGCombiner::getStoreMergeCandidates( return false; // The Load's Base Ptr must also match if (LoadSDNode *OtherLd = dyn_cast<LoadSDNode>(Val)) { - auto LPtr = BaseIndexOffset::match(OtherLd->getBasePtr(), DAG); + auto LPtr = BaseIndexOffset::match(OtherLd, DAG); if (LoadVT != OtherLd->getMemoryVT()) return false; + // Loads must only have one use. + if (!OtherLd->hasNUsesOfValue(1, 0)) + return false; + // The memory operands must not be volatile. 
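+ // (Indexed loads are rejected as well: an indexed load produces an extra + // result that the merged replacement could not preserve.)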
+ if (OtherLd->isVolatile() || OtherLd->isIndexed()) + return false; if (!(LBasePtr.equalBaseIndex(LPtr, DAG))) return false; } else @@ -12993,7 +13743,7 @@ void DAGCombiner::getStoreMergeCandidates( Val.getOpcode() != ISD::EXTRACT_SUBVECTOR) return false; } - Ptr = BaseIndexOffset::match(Other->getBasePtr(), DAG); + Ptr = BaseIndexOffset::match(Other, DAG); return (BasePtr.equalBaseIndex(Ptr, DAG, Offset)); }; @@ -13013,7 +13763,7 @@ void DAGCombiner::getStoreMergeCandidates( // FIXME: We should be able to climb and // descend TokenFactors to find candidates as well. - SDNode *RootNode = (St->getChain()).getNode(); + RootNode = St->getChain().getNode(); if (LoadSDNode *Ldn = dyn_cast<LoadSDNode>(RootNode)) { RootNode = Ldn->getChain().getNode(); @@ -13044,31 +13794,54 @@ void DAGCombiner::getStoreMergeCandidates( // through the chain). Check in parallel by searching up from // non-chain operands of candidates. bool DAGCombiner::checkMergeStoreCandidatesForDependencies( - SmallVectorImpl<MemOpLink> &StoreNodes, unsigned NumStores) { + SmallVectorImpl<MemOpLink> &StoreNodes, unsigned NumStores, + SDNode *RootNode) { // FIXME: We should be able to truncate a full search of // predecessors by doing a BFS and keeping tabs on the originating // stores from which worklist nodes come, in a similar way to // TokenFactor simplification. - SmallPtrSet<const SDNode *, 16> Visited; + SmallPtrSet<const SDNode *, 32> Visited; SmallVector<const SDNode *, 8> Worklist; - unsigned int Max = 8192; + + // RootNode is a predecessor to all candidates, so we need not search + // past it. Add RootNode (peeking through TokenFactors). Do not count + // these towards the size check. + + Worklist.push_back(RootNode); + while (!Worklist.empty()) { + auto N = Worklist.pop_back_val(); + if (N->getOpcode() == ISD::TokenFactor) { + for (SDValue Op : N->ops()) + Worklist.push_back(Op.getNode()); + } + Visited.insert(N); + } + + // Don't count pruning nodes towards max. + unsigned int Max = 1024 + Visited.size(); // Search Ops of store candidates. for (unsigned i = 0; i < NumStores; ++i) { - SDNode *n = StoreNodes[i].MemNode; - // Potential loops may happen only through non-chain operands - for (unsigned j = 1; j < n->getNumOperands(); ++j) - Worklist.push_back(n->getOperand(j).getNode()); + SDNode *N = StoreNodes[i].MemNode; + // Of the 4 Store Operands: + // * Chain (Op 0) -> We have already considered these + // in candidate selection, so they can be + // safely ignored + // * Value (Op 1) -> Cycles may happen (e.g. through load chains) + // * Address (Op 2) -> Merged addresses may only vary by a fixed constant + // and so no cycles are possible. + // * (Op 3) -> appears to always be undef. Cannot be source of cycle. + // + // Thus we need only check predecessors of the value operands. + auto *Op = N->getOperand(1).getNode(); + if (Visited.insert(Op).second) + Worklist.push_back(Op); } // Search through DAG. We can stop early if we find a store node. - for (unsigned i = 0; i < NumStores; ++i) { + for (unsigned i = 0; i < NumStores; ++i) if (SDNode::hasPredecessorHelper(StoreNodes[i].MemNode, Visited, Worklist, Max)) return false; - // Check if we ended early, failing conservatively if so.
- if (Visited.size() >= Max) - return false; - } return true; } @@ -13106,8 +13879,9 @@ bool DAGCombiner::MergeConsecutiveStores(StoreSDNode *St) { return false; SmallVector<MemOpLink, 8> StoreNodes; + SDNode *RootNode; // Find potential store merge candidates by searching through chain sub-DAG - getStoreMergeCandidates(St, StoreNodes); + getStoreMergeCandidates(St, StoreNodes, RootNode); // Check if there is anything to merge. if (StoreNodes.size() < 2) @@ -13115,10 +13889,10 @@ bool DAGCombiner::MergeConsecutiveStores(StoreSDNode *St) { // Sort the memory operands according to their distance from the // base pointer. - std::sort(StoreNodes.begin(), StoreNodes.end(), - [](MemOpLink LHS, MemOpLink RHS) { - return LHS.OffsetFromBase < RHS.OffsetFromBase; - }); + llvm::sort(StoreNodes.begin(), StoreNodes.end(), + [](MemOpLink LHS, MemOpLink RHS) { + return LHS.OffsetFromBase < RHS.OffsetFromBase; + }); // Store Merge attempts to merge the lowest stores. This generally // works out as if successful, as the remaining stores are checked @@ -13162,178 +13936,191 @@ bool DAGCombiner::MergeConsecutiveStores(StoreSDNode *St) { continue; } - // Check that we can merge these candidates without causing a cycle - if (!checkMergeStoreCandidatesForDependencies(StoreNodes, - NumConsecutiveStores)) { - StoreNodes.erase(StoreNodes.begin(), - StoreNodes.begin() + NumConsecutiveStores); - continue; - } - // The node with the lowest store address. LLVMContext &Context = *DAG.getContext(); const DataLayout &DL = DAG.getDataLayout(); // Store the constants into memory as one consecutive store. if (IsConstantSrc) { - LSBaseSDNode *FirstInChain = StoreNodes[0].MemNode; - unsigned FirstStoreAS = FirstInChain->getAddressSpace(); - unsigned FirstStoreAlign = FirstInChain->getAlignment(); - unsigned LastLegalType = 1; - unsigned LastLegalVectorType = 1; - bool LastIntegerTrunc = false; - bool NonZero = false; - unsigned FirstZeroAfterNonZero = NumConsecutiveStores; - for (unsigned i = 0; i < NumConsecutiveStores; ++i) { - StoreSDNode *ST = cast<StoreSDNode>(StoreNodes[i].MemNode); - SDValue StoredVal = ST->getValue(); - bool IsElementZero = false; - if (ConstantSDNode *C = dyn_cast<ConstantSDNode>(StoredVal)) - IsElementZero = C->isNullValue(); - else if (ConstantFPSDNode *C = dyn_cast<ConstantFPSDNode>(StoredVal)) - IsElementZero = C->getConstantFPValue()->isNullValue(); - if (IsElementZero) { - if (NonZero && FirstZeroAfterNonZero == NumConsecutiveStores) - FirstZeroAfterNonZero = i; - } - NonZero |= !IsElementZero; + while (NumConsecutiveStores >= 2) { + LSBaseSDNode *FirstInChain = StoreNodes[0].MemNode; + unsigned FirstStoreAS = FirstInChain->getAddressSpace(); + unsigned FirstStoreAlign = FirstInChain->getAlignment(); + unsigned LastLegalType = 1; + unsigned LastLegalVectorType = 1; + bool LastIntegerTrunc = false; + bool NonZero = false; + unsigned FirstZeroAfterNonZero = NumConsecutiveStores; + for (unsigned i = 0; i < NumConsecutiveStores; ++i) { + StoreSDNode *ST = cast<StoreSDNode>(StoreNodes[i].MemNode); + SDValue StoredVal = ST->getValue(); + bool IsElementZero = false; + if (ConstantSDNode *C = dyn_cast<ConstantSDNode>(StoredVal)) + IsElementZero = C->isNullValue(); + else if (ConstantFPSDNode *C = dyn_cast<ConstantFPSDNode>(StoredVal)) + IsElementZero = C->getConstantFPValue()->isNullValue(); + if (IsElementZero) { + if (NonZero && FirstZeroAfterNonZero == NumConsecutiveStores) + FirstZeroAfterNonZero = i; + } + NonZero |= !IsElementZero; - // Find a legal type for the constant store. 
- unsigned SizeInBits = (i + 1) * ElementSizeBytes * 8; - EVT StoreTy = EVT::getIntegerVT(Context, SizeInBits); - bool IsFast = false; - if (TLI.isTypeLegal(StoreTy) && - TLI.canMergeStoresTo(FirstStoreAS, StoreTy, DAG) && - TLI.allowsMemoryAccess(Context, DL, StoreTy, FirstStoreAS, - FirstStoreAlign, &IsFast) && - IsFast) { - LastIntegerTrunc = false; - LastLegalType = i + 1; - // Or check whether a truncstore is legal. - } else if (TLI.getTypeAction(Context, StoreTy) == - TargetLowering::TypePromoteInteger) { - EVT LegalizedStoredValueTy = - TLI.getTypeToTransformTo(Context, StoredVal.getValueType()); - if (TLI.isTruncStoreLegal(LegalizedStoredValueTy, StoreTy) && - TLI.canMergeStoresTo(FirstStoreAS, LegalizedStoredValueTy, DAG) && + // Find a legal type for the constant store. + unsigned SizeInBits = (i + 1) * ElementSizeBytes * 8; + EVT StoreTy = EVT::getIntegerVT(Context, SizeInBits); + bool IsFast = false; + + // Break early when size is too large to be legal. + if (StoreTy.getSizeInBits() > MaximumLegalStoreInBits) + break; + + if (TLI.isTypeLegal(StoreTy) && + TLI.canMergeStoresTo(FirstStoreAS, StoreTy, DAG) && TLI.allowsMemoryAccess(Context, DL, StoreTy, FirstStoreAS, FirstStoreAlign, &IsFast) && IsFast) { - LastIntegerTrunc = true; + LastIntegerTrunc = false; LastLegalType = i + 1; + // Or check whether a truncstore is legal. + } else if (TLI.getTypeAction(Context, StoreTy) == + TargetLowering::TypePromoteInteger) { + EVT LegalizedStoredValTy = + TLI.getTypeToTransformTo(Context, StoredVal.getValueType()); + if (TLI.isTruncStoreLegal(LegalizedStoredValTy, StoreTy) && + TLI.canMergeStoresTo(FirstStoreAS, LegalizedStoredValTy, DAG) && + TLI.allowsMemoryAccess(Context, DL, StoreTy, FirstStoreAS, + FirstStoreAlign, &IsFast) && + IsFast) { + LastIntegerTrunc = true; + LastLegalType = i + 1; + } } - } - // We only use vectors if the constant is known to be zero or the target - // allows it and the function is not marked with the noimplicitfloat - // attribute. - if ((!NonZero || - TLI.storeOfVectorConstantIsCheap(MemVT, i + 1, FirstStoreAS)) && - !NoVectors) { - // Find a legal type for the vector store. - unsigned Elts = (i + 1) * NumMemElts; - EVT Ty = EVT::getVectorVT(Context, MemVT.getScalarType(), Elts); - if (TLI.isTypeLegal(Ty) && TLI.isTypeLegal(MemVT) && - TLI.canMergeStoresTo(FirstStoreAS, Ty, DAG) && - TLI.allowsMemoryAccess(Context, DL, Ty, FirstStoreAS, - FirstStoreAlign, &IsFast) && - IsFast) - LastLegalVectorType = i + 1; + // We only use vectors if the constant is known to be zero or the + // target allows it and the function is not marked with the + // noimplicitfloat attribute. + if ((!NonZero || + TLI.storeOfVectorConstantIsCheap(MemVT, i + 1, FirstStoreAS)) && + !NoVectors) { + // Find a legal type for the vector store. + unsigned Elts = (i + 1) * NumMemElts; + EVT Ty = EVT::getVectorVT(Context, MemVT.getScalarType(), Elts); + if (TLI.isTypeLegal(Ty) && TLI.isTypeLegal(MemVT) && + TLI.canMergeStoresTo(FirstStoreAS, Ty, DAG) && + TLI.allowsMemoryAccess(Context, DL, Ty, FirstStoreAS, + FirstStoreAlign, &IsFast) && + IsFast) + LastLegalVectorType = i + 1; + } } - } - bool UseVector = (LastLegalVectorType > LastLegalType) && !NoVectors; - unsigned NumElem = (UseVector) ? LastLegalVectorType : LastLegalType; + bool UseVector = (LastLegalVectorType > LastLegalType) && !NoVectors; + unsigned NumElem = (UseVector) ? LastLegalVectorType : LastLegalType; + + // Check if we found a legal integer type that creates a meaningful + // merge. 
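+ // E.g., merging two consecutive i8 constant stores into one i16 store + // requires the i16 store (or the corresponding truncating store) to be + // legal and fast for the target.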
+ if (NumElem < 2) { + // We know that candidate stores are in order and of correct + // shape. While there is no mergeable sequence from the + // beginning one may start later in the sequence. The only + // reason a merge of size N could have failed where another of + // the same size would not have, is if the alignment has + // improved or we've dropped a non-zero value. Drop as many + // candidates as we can here. + unsigned NumSkip = 1; + while ( + (NumSkip < NumConsecutiveStores) && + (NumSkip < FirstZeroAfterNonZero) && + (StoreNodes[NumSkip].MemNode->getAlignment() <= FirstStoreAlign)) + NumSkip++; + + StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + NumSkip); + NumConsecutiveStores -= NumSkip; + continue; + } - // Check if we found a legal integer type that creates a meaningful merge. - if (NumElem < 2) { - // We know that candidate stores are in order and of correct - // shape. While there is no mergeable sequence from the - // beginning one may start later in the sequence. The only - // reason a merge of size N could have failed where another of - // the same size would not have, is if the alignment has - // improved or we've dropped a non-zero value. Drop as many - // candidates as we can here. - unsigned NumSkip = 1; - while ( - (NumSkip < NumConsecutiveStores) && - (NumSkip < FirstZeroAfterNonZero) && - (StoreNodes[NumSkip].MemNode->getAlignment() <= FirstStoreAlign)) { - NumSkip++; + // Check that we can merge these candidates without causing a cycle. + if (!checkMergeStoreCandidatesForDependencies(StoreNodes, NumElem, + RootNode)) { + StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + NumElem); + NumConsecutiveStores -= NumElem; + continue; } - StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + NumSkip); - continue; - } - bool Merged = MergeStoresOfConstantsOrVecElts( - StoreNodes, MemVT, NumElem, true, UseVector, LastIntegerTrunc); - RV |= Merged; + RV |= MergeStoresOfConstantsOrVecElts(StoreNodes, MemVT, NumElem, true, + UseVector, LastIntegerTrunc); - // Remove merged stores for next iteration. - StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + NumElem); + // Remove merged stores for next iteration. + StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + NumElem); + NumConsecutiveStores -= NumElem; + } continue; } // When extracting multiple vector elements, try to store them // in one vector store rather than a sequence of scalar stores. if (IsExtractVecSrc) { - LSBaseSDNode *FirstInChain = StoreNodes[0].MemNode; - unsigned FirstStoreAS = FirstInChain->getAddressSpace(); - unsigned FirstStoreAlign = FirstInChain->getAlignment(); - unsigned NumStoresToMerge = 1; - for (unsigned i = 0; i < NumConsecutiveStores; ++i) { - StoreSDNode *St = cast<StoreSDNode>(StoreNodes[i].MemNode); - SDValue StVal = peekThroughBitcast(St->getValue()); - // This restriction could be loosened. - // Bail out if any stored values are not elements extracted from a - // vector. It should be possible to handle mixed sources, but load - // sources need more careful handling (see the block of code below that - // handles consecutive loads). - if (StVal.getOpcode() != ISD::EXTRACT_VECTOR_ELT && - StVal.getOpcode() != ISD::EXTRACT_SUBVECTOR) - return RV; + // Loop on Consecutive Stores on success. 
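+ // Every path below either merges a prefix of StoreNodes or drops + // unmergeable candidates, so NumConsecutiveStores strictly decreases and + // the loop terminates.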
+ while (NumConsecutiveStores >= 2) { + LSBaseSDNode *FirstInChain = StoreNodes[0].MemNode; + unsigned FirstStoreAS = FirstInChain->getAddressSpace(); + unsigned FirstStoreAlign = FirstInChain->getAlignment(); + unsigned NumStoresToMerge = 1; + for (unsigned i = 0; i < NumConsecutiveStores; ++i) { + // Find a legal type for the vector store. + unsigned Elts = (i + 1) * NumMemElts; + EVT Ty = + EVT::getVectorVT(*DAG.getContext(), MemVT.getScalarType(), Elts); + bool IsFast; - // Find a legal type for the vector store. - unsigned Elts = (i + 1) * NumMemElts; - EVT Ty = - EVT::getVectorVT(*DAG.getContext(), MemVT.getScalarType(), Elts); - bool IsFast; - if (TLI.isTypeLegal(Ty) && - TLI.canMergeStoresTo(FirstStoreAS, Ty, DAG) && - TLI.allowsMemoryAccess(Context, DL, Ty, FirstStoreAS, - FirstStoreAlign, &IsFast) && - IsFast) - NumStoresToMerge = i + 1; - } + // Break early when size is too large to be legal. + if (Ty.getSizeInBits() > MaximumLegalStoreInBits) + break; - // Check if we found a legal integer type that creates a meaningful merge. - if (NumStoresToMerge < 2) { - // We know that candidate stores are in order and of correct - // shape. While there is no mergeable sequence from the - // beginning one may start later in the sequence. The only - // reason a merge of size N could have failed where another of - // the same size would not have, is if the alignment has - // improved. Drop as many candidates as we can here. - unsigned NumSkip = 1; - while ((NumSkip < NumConsecutiveStores) && - (StoreNodes[NumSkip].MemNode->getAlignment() <= FirstStoreAlign)) - NumSkip++; + if (TLI.isTypeLegal(Ty) && + TLI.canMergeStoresTo(FirstStoreAS, Ty, DAG) && + TLI.allowsMemoryAccess(Context, DL, Ty, FirstStoreAS, + FirstStoreAlign, &IsFast) && + IsFast) + NumStoresToMerge = i + 1; + } - StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + NumSkip); - continue; - } + // Check if we found a legal integer type creating a meaningful + // merge. + if (NumStoresToMerge < 2) { + // We know that candidate stores are in order and of correct + // shape. While there is no mergeable sequence from the + // beginning one may start later in the sequence. The only + // reason a merge of size N could have failed where another of + // the same size would not have, is if the alignment has + // improved. Drop as many candidates as we can here. + unsigned NumSkip = 1; + while ( + (NumSkip < NumConsecutiveStores) && + (StoreNodes[NumSkip].MemNode->getAlignment() <= FirstStoreAlign)) + NumSkip++; + + StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + NumSkip); + NumConsecutiveStores -= NumSkip; + continue; + } + + // Check that we can merge these candidates without causing a cycle. + if (!checkMergeStoreCandidatesForDependencies( + StoreNodes, NumStoresToMerge, RootNode)) { + StoreNodes.erase(StoreNodes.begin(), + StoreNodes.begin() + NumStoresToMerge); + NumConsecutiveStores -= NumStoresToMerge; + continue; + } + + RV |= MergeStoresOfConstantsOrVecElts( + StoreNodes, MemVT, NumStoresToMerge, false, true, false); - bool Merged = MergeStoresOfConstantsOrVecElts( - StoreNodes, MemVT, NumStoresToMerge, false, true, false); - if (!Merged) { StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + NumStoresToMerge); - continue; + NumConsecutiveStores -= NumStoresToMerge; } - // Remove merged stores for next iteration. 
- StoreNodes.erase(StoreNodes.begin(), - StoreNodes.begin() + NumStoresToMerge); - RV = true; continue; } @@ -13347,26 +14134,13 @@ bool DAGCombiner::MergeConsecutiveStores(StoreSDNode *St) { // Find acceptable loads. Loads need to have the same chain (token factor), // must not be zext, volatile, indexed, and they must be consecutive. BaseIndexOffset LdBasePtr; + for (unsigned i = 0; i < NumConsecutiveStores; ++i) { StoreSDNode *St = cast<StoreSDNode>(StoreNodes[i].MemNode); SDValue Val = peekThroughBitcast(St->getValue()); - LoadSDNode *Ld = dyn_cast<LoadSDNode>(Val); - if (!Ld) - break; + LoadSDNode *Ld = cast<LoadSDNode>(Val); - // Loads must only have one use. - if (!Ld->hasNUsesOfValue(1, 0)) - break; - - // The memory operands must not be volatile. - if (Ld->isVolatile() || Ld->isIndexed()) - break; - - // The stored memory type must be the same. - if (Ld->getMemoryVT() != MemVT) - break; - - BaseIndexOffset LdPtr = BaseIndexOffset::match(Ld->getBasePtr(), DAG); + BaseIndexOffset LdPtr = BaseIndexOffset::match(Ld, DAG); // If this is not the first ptr that we check. int64_t LdOffset = 0; if (LdBasePtr.getBase().getNode()) { @@ -13382,90 +14156,75 @@ bool DAGCombiner::MergeConsecutiveStores(StoreSDNode *St) { LoadNodes.push_back(MemOpLink(Ld, LdOffset)); } - if (LoadNodes.size() < 2) { - StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + 1); - continue; - } + while (NumConsecutiveStores >= 2 && LoadNodes.size() >= 2) { + // If we have load/store pair instructions and we only have two values, + // don't bother merging. + unsigned RequiredAlignment; + if (LoadNodes.size() == 2 && + TLI.hasPairedLoad(MemVT, RequiredAlignment) && + StoreNodes[0].MemNode->getAlignment() >= RequiredAlignment) { + StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + 2); + LoadNodes.erase(LoadNodes.begin(), LoadNodes.begin() + 2); + break; + } + LSBaseSDNode *FirstInChain = StoreNodes[0].MemNode; + unsigned FirstStoreAS = FirstInChain->getAddressSpace(); + unsigned FirstStoreAlign = FirstInChain->getAlignment(); + LoadSDNode *FirstLoad = cast<LoadSDNode>(LoadNodes[0].MemNode); + unsigned FirstLoadAS = FirstLoad->getAddressSpace(); + unsigned FirstLoadAlign = FirstLoad->getAlignment(); - // If we have load/store pair instructions and we only have two values, - // don't bother merging. - unsigned RequiredAlignment; - if (LoadNodes.size() == 2 && TLI.hasPairedLoad(MemVT, RequiredAlignment) && - StoreNodes[0].MemNode->getAlignment() >= RequiredAlignment) { - StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + 2); - continue; - } - LSBaseSDNode *FirstInChain = StoreNodes[0].MemNode; - unsigned FirstStoreAS = FirstInChain->getAddressSpace(); - unsigned FirstStoreAlign = FirstInChain->getAlignment(); - LoadSDNode *FirstLoad = cast<LoadSDNode>(LoadNodes[0].MemNode); - unsigned FirstLoadAS = FirstLoad->getAddressSpace(); - unsigned FirstLoadAlign = FirstLoad->getAlignment(); + // Scan the memory operations on the chain and find the first + // non-consecutive load memory address. These variables hold the index in + // the store node array. - // Scan the memory operations on the chain and find the first - // non-consecutive load memory address. These variables hold the index in - // the store node array. - unsigned LastConsecutiveLoad = 1; - // This variable refers to the size and not index in the array. 
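Each candidate load is decomposed into a base plus a byte offset (BaseIndexOffset::match now takes the whole memory node rather than just its pointer operand), and only a prefix whose offsets advance by exactly one element survives, as the loop further below checks. A standalone sketch of that consecutiveness test, with Addr standing in for BaseIndexOffset:

#include <cstdint>
#include <cstdio>
#include <vector>

// Simplified BaseIndexOffset: an address decomposed as Base + Offset.
struct Addr { const void *Base; int64_t Offset; };

int main() {
  int Buf[8] = {};
  const int64_t ElementSizeBytes = sizeof(int);
  std::vector<Addr> Loads = {
      {Buf, 0}, {Buf, 4}, {Buf, 8}, {Buf, 20}}; // last one breaks the run

  // Loads must share a base and be exactly ElementSizeBytes apart.
  unsigned LastConsecutive = 0;
  for (unsigned i = 1; i < Loads.size(); ++i) {
    if (Loads[i].Base != Loads[0].Base ||
        Loads[i].Offset - Loads[0].Offset != ElementSizeBytes * i)
      break;
    LastConsecutive = i;
  }
  std::printf("consecutive run ends at index %u\n", LastConsecutive);
}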
- unsigned LastLegalVectorType = 1; - unsigned LastLegalIntegerType = 1; - bool isDereferenceable = true; - bool DoIntegerTruncate = false; - StartAddress = LoadNodes[0].OffsetFromBase; - SDValue FirstChain = FirstLoad->getChain(); - for (unsigned i = 1; i < LoadNodes.size(); ++i) { - // All loads must share the same chain. - if (LoadNodes[i].MemNode->getChain() != FirstChain) - break; + unsigned LastConsecutiveLoad = 1; - int64_t CurrAddress = LoadNodes[i].OffsetFromBase; - if (CurrAddress - StartAddress != (ElementSizeBytes * i)) - break; - LastConsecutiveLoad = i; - - if (isDereferenceable && !LoadNodes[i].MemNode->isDereferenceable()) - isDereferenceable = false; - - // Find a legal type for the vector store. - unsigned Elts = (i + 1) * NumMemElts; - EVT StoreTy = EVT::getVectorVT(Context, MemVT.getScalarType(), Elts); - - bool IsFastSt, IsFastLd; - if (TLI.isTypeLegal(StoreTy) && - TLI.canMergeStoresTo(FirstStoreAS, StoreTy, DAG) && - TLI.allowsMemoryAccess(Context, DL, StoreTy, FirstStoreAS, - FirstStoreAlign, &IsFastSt) && - IsFastSt && - TLI.allowsMemoryAccess(Context, DL, StoreTy, FirstLoadAS, - FirstLoadAlign, &IsFastLd) && - IsFastLd) { - LastLegalVectorType = i + 1; - } + // This variable refers to the size and not index in the array. + unsigned LastLegalVectorType = 1; + unsigned LastLegalIntegerType = 1; + bool isDereferenceable = true; + bool DoIntegerTruncate = false; + StartAddress = LoadNodes[0].OffsetFromBase; + SDValue FirstChain = FirstLoad->getChain(); + for (unsigned i = 1; i < LoadNodes.size(); ++i) { + // All loads must share the same chain. + if (LoadNodes[i].MemNode->getChain() != FirstChain) + break; + + int64_t CurrAddress = LoadNodes[i].OffsetFromBase; + if (CurrAddress - StartAddress != (ElementSizeBytes * i)) + break; + LastConsecutiveLoad = i; - // Find a legal type for the integer store. - unsigned SizeInBits = (i + 1) * ElementSizeBytes * 8; - StoreTy = EVT::getIntegerVT(Context, SizeInBits); - if (TLI.isTypeLegal(StoreTy) && - TLI.canMergeStoresTo(FirstStoreAS, StoreTy, DAG) && - TLI.allowsMemoryAccess(Context, DL, StoreTy, FirstStoreAS, - FirstStoreAlign, &IsFastSt) && - IsFastSt && - TLI.allowsMemoryAccess(Context, DL, StoreTy, FirstLoadAS, - FirstLoadAlign, &IsFastLd) && - IsFastLd) { - LastLegalIntegerType = i + 1; - DoIntegerTruncate = false; - // Or check whether a truncstore and extload is legal. - } else if (TLI.getTypeAction(Context, StoreTy) == - TargetLowering::TypePromoteInteger) { - EVT LegalizedStoredValueTy = TLI.getTypeToTransformTo(Context, StoreTy); - if (TLI.isTruncStoreLegal(LegalizedStoredValueTy, StoreTy) && - TLI.canMergeStoresTo(FirstStoreAS, LegalizedStoredValueTy, DAG) && - TLI.isLoadExtLegal(ISD::ZEXTLOAD, LegalizedStoredValueTy, - StoreTy) && - TLI.isLoadExtLegal(ISD::SEXTLOAD, LegalizedStoredValueTy, - StoreTy) && - TLI.isLoadExtLegal(ISD::EXTLOAD, LegalizedStoredValueTy, StoreTy) && + if (isDereferenceable && !LoadNodes[i].MemNode->isDereferenceable()) + isDereferenceable = false; + + // Find a legal type for the vector store. + unsigned Elts = (i + 1) * NumMemElts; + EVT StoreTy = EVT::getVectorVT(Context, MemVT.getScalarType(), Elts); + + // Break early when size is too large to be legal. 
+ if (StoreTy.getSizeInBits() > MaximumLegalStoreInBits) + break; + + bool IsFastSt, IsFastLd; + if (TLI.isTypeLegal(StoreTy) && + TLI.canMergeStoresTo(FirstStoreAS, StoreTy, DAG) && + TLI.allowsMemoryAccess(Context, DL, StoreTy, FirstStoreAS, + FirstStoreAlign, &IsFastSt) && + IsFastSt && + TLI.allowsMemoryAccess(Context, DL, StoreTy, FirstLoadAS, + FirstLoadAlign, &IsFastLd) && + IsFastLd) { + LastLegalVectorType = i + 1; + } + + // Find a legal type for the integer store. + unsigned SizeInBits = (i + 1) * ElementSizeBytes * 8; + StoreTy = EVT::getIntegerVT(Context, SizeInBits); + if (TLI.isTypeLegal(StoreTy) && + TLI.canMergeStoresTo(FirstStoreAS, StoreTy, DAG) && TLI.allowsMemoryAccess(Context, DL, StoreTy, FirstStoreAS, FirstStoreAlign, &IsFastSt) && IsFastSt && @@ -13473,105 +14232,140 @@ bool DAGCombiner::MergeConsecutiveStores(StoreSDNode *St) { FirstLoadAlign, &IsFastLd) && IsFastLd) { LastLegalIntegerType = i + 1; - DoIntegerTruncate = true; + DoIntegerTruncate = false; + // Or check whether a truncstore and extload is legal. + } else if (TLI.getTypeAction(Context, StoreTy) == + TargetLowering::TypePromoteInteger) { + EVT LegalizedStoredValTy = TLI.getTypeToTransformTo(Context, StoreTy); + if (TLI.isTruncStoreLegal(LegalizedStoredValTy, StoreTy) && + TLI.canMergeStoresTo(FirstStoreAS, LegalizedStoredValTy, DAG) && + TLI.isLoadExtLegal(ISD::ZEXTLOAD, LegalizedStoredValTy, + StoreTy) && + TLI.isLoadExtLegal(ISD::SEXTLOAD, LegalizedStoredValTy, + StoreTy) && + TLI.isLoadExtLegal(ISD::EXTLOAD, LegalizedStoredValTy, StoreTy) && + TLI.allowsMemoryAccess(Context, DL, StoreTy, FirstStoreAS, + FirstStoreAlign, &IsFastSt) && + IsFastSt && + TLI.allowsMemoryAccess(Context, DL, StoreTy, FirstLoadAS, + FirstLoadAlign, &IsFastLd) && + IsFastLd) { + LastLegalIntegerType = i + 1; + DoIntegerTruncate = true; + } } } - } - // Only use vector types if the vector type is larger than the integer type. - // If they are the same, use integers. - bool UseVectorTy = LastLegalVectorType > LastLegalIntegerType && !NoVectors; - unsigned LastLegalType = - std::max(LastLegalVectorType, LastLegalIntegerType); - - // We add +1 here because the LastXXX variables refer to location while - // the NumElem refers to array/index size. - unsigned NumElem = std::min(NumConsecutiveStores, LastConsecutiveLoad + 1); - NumElem = std::min(LastLegalType, NumElem); - - if (NumElem < 2) { - // We know that candidate stores are in order and of correct - // shape. While there is no mergeable sequence from the - // beginning one may start later in the sequence. The only - // reason a merge of size N could have failed where another of - // the same size would not have is if the alignment or either - // the load or store has improved. Drop as many candidates as we - // can here. - unsigned NumSkip = 1; - while ((NumSkip < LoadNodes.size()) && - (LoadNodes[NumSkip].MemNode->getAlignment() <= FirstLoadAlign) && - (StoreNodes[NumSkip].MemNode->getAlignment() <= FirstStoreAlign)) - NumSkip++; - StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + NumSkip); - continue; - } + // Only use vector types if the vector type is larger than the integer + // type. If they are the same, use integers. + bool UseVectorTy = + LastLegalVectorType > LastLegalIntegerType && !NoVectors; + unsigned LastLegalType = + std::max(LastLegalVectorType, LastLegalIntegerType); - // Find if it is better to use vectors or integers to load and store - // to memory. - EVT JointMemOpVT; - if (UseVectorTy) { - // Find a legal type for the vector store. 
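The legality ladder above records, width by width, whichever of three strategies works: a plain vector store, a plain integer store, or, when the integer type would be promoted, a truncating store paired with extending loads (the DoIntegerTruncate case). A toy decision function under invented target rules (vectorLegal and integerLegal are assumptions, not TLI queries):

#include <cstdio>

// Names here are illustrative, not the DAG's.
enum class StoreKind { Vector, Integer, TruncStoreFromPromoted };

// Hypothetical target: vectors legal up to 128 bits, integers up to 64,
// and wider integers are promoted, so a truncating store is the fallback.
static bool vectorLegal(unsigned Bits)  { return Bits <= 128; }
static bool integerLegal(unsigned Bits) { return Bits <= 64; }

static StoreKind pickKind(unsigned Bits) {
  if (vectorLegal(Bits))  return StoreKind::Vector;
  if (integerLegal(Bits)) return StoreKind::Integer;
  return StoreKind::TruncStoreFromPromoted; // DoIntegerTruncate = true
}

int main() {
  for (unsigned Bits : {64u, 96u, 256u})
    std::printf("%u bits -> kind %d\n", Bits, static_cast<int>(pickKind(Bits)));
}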
-    unsigned Elts = NumElem * NumMemElts;
-    JointMemOpVT = EVT::getVectorVT(Context, MemVT.getScalarType(), Elts);
-  } else {
-    unsigned SizeInBits = NumElem * ElementSizeBytes * 8;
-    JointMemOpVT = EVT::getIntegerVT(Context, SizeInBits);
-  }
-
-  SDLoc LoadDL(LoadNodes[0].MemNode);
-  SDLoc StoreDL(StoreNodes[0].MemNode);
-
-  // The merged loads are required to have the same incoming chain, so
-  // using the first's chain is acceptable.
-
-  SDValue NewStoreChain = getMergeStoreChains(StoreNodes, NumElem);
-  AddToWorklist(NewStoreChain.getNode());
-
-  MachineMemOperand::Flags MMOFlags = isDereferenceable ?
-                                        MachineMemOperand::MODereferenceable:
-                                        MachineMemOperand::MONone;
-
-  SDValue NewLoad, NewStore;
-  if (UseVectorTy || !DoIntegerTruncate) {
-    NewLoad = DAG.getLoad(JointMemOpVT, LoadDL, FirstLoad->getChain(),
-                          FirstLoad->getBasePtr(),
-                          FirstLoad->getPointerInfo(), FirstLoadAlign,
-                          MMOFlags);
-    NewStore = DAG.getStore(NewStoreChain, StoreDL, NewLoad,
-                            FirstInChain->getBasePtr(),
-                            FirstInChain->getPointerInfo(), FirstStoreAlign);
-  } else { // This must be the truncstore/extload case
-    EVT ExtendedTy =
-        TLI.getTypeToTransformTo(*DAG.getContext(), JointMemOpVT);
-    NewLoad =
-        DAG.getExtLoad(ISD::EXTLOAD, LoadDL, ExtendedTy, FirstLoad->getChain(),
-                       FirstLoad->getBasePtr(), FirstLoad->getPointerInfo(),
-                       JointMemOpVT, FirstLoadAlign, MMOFlags);
-    NewStore = DAG.getTruncStore(NewStoreChain, StoreDL, NewLoad,
-                                 FirstInChain->getBasePtr(),
-                                 FirstInChain->getPointerInfo(), JointMemOpVT,
-                                 FirstInChain->getAlignment(),
-                                 FirstInChain->getMemOperand()->getFlags());
-  }
-
-  // Transfer chain users from old loads to the new load.
-  for (unsigned i = 0; i < NumElem; ++i) {
-    LoadSDNode *Ld = cast<LoadSDNode>(LoadNodes[i].MemNode);
-    DAG.ReplaceAllUsesOfValueWith(SDValue(Ld, 1),
-                                  SDValue(NewLoad.getNode(), 1));
-  }
-
-  // Replace the all stores with the new store. Recursively remove
-  // corresponding value if its no longer used.
-  for (unsigned i = 0; i < NumElem; ++i) {
-    SDValue Val = StoreNodes[i].MemNode->getOperand(1);
-    CombineTo(StoreNodes[i].MemNode, NewStore);
-    if (Val.getNode()->use_empty())
-      recursivelyDeleteUnusedNodes(Val.getNode());
-  }
-
-  RV = true;
-  StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + NumElem);
+      // We add +1 here because the LastXXX variables refer to location while
+      // the NumElem refers to array/index size.
+      unsigned NumElem =
+          std::min(NumConsecutiveStores, LastConsecutiveLoad + 1);
+      NumElem = std::min(LastLegalType, NumElem);
+
+      if (NumElem < 2) {
+        // We know that candidate stores are in order and of correct
+        // shape. While there is no mergeable sequence from the
+        // beginning one may start later in the sequence. The only
+        // reason a merge of size N could have failed where another of
+        // the same size would not have is if the alignment of either
+        // the load or store has improved. Drop as many candidates as we
+        // can here.
+        unsigned NumSkip = 1;
+        while ((NumSkip < LoadNodes.size()) &&
+               (LoadNodes[NumSkip].MemNode->getAlignment() <= FirstLoadAlign) &&
+               (StoreNodes[NumSkip].MemNode->getAlignment() <= FirstStoreAlign))
+          NumSkip++;
+        StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + NumSkip);
+        LoadNodes.erase(LoadNodes.begin(), LoadNodes.begin() + NumSkip);
+        NumConsecutiveStores -= NumSkip;
+        continue;
+      }
+
+      // Check that we can merge these candidates without causing a cycle.
+      if (!checkMergeStoreCandidatesForDependencies(StoreNodes, NumElem,
+                                                    RootNode)) {
+        StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + NumElem);
+        LoadNodes.erase(LoadNodes.begin(), LoadNodes.begin() + NumElem);
+        NumConsecutiveStores -= NumElem;
+        continue;
+      }
+
+      // Find if it is better to use vectors or integers to load and store
+      // to memory.
+      EVT JointMemOpVT;
+      if (UseVectorTy) {
+        // Find a legal type for the vector store.
+        unsigned Elts = NumElem * NumMemElts;
+        JointMemOpVT = EVT::getVectorVT(Context, MemVT.getScalarType(), Elts);
+      } else {
+        unsigned SizeInBits = NumElem * ElementSizeBytes * 8;
+        JointMemOpVT = EVT::getIntegerVT(Context, SizeInBits);
+      }
+
+      SDLoc LoadDL(LoadNodes[0].MemNode);
+      SDLoc StoreDL(StoreNodes[0].MemNode);
+
+      // The merged loads are required to have the same incoming chain, so
+      // using the first's chain is acceptable.
+
+      SDValue NewStoreChain = getMergeStoreChains(StoreNodes, NumElem);
+      AddToWorklist(NewStoreChain.getNode());
+
+      MachineMemOperand::Flags MMOFlags =
+          isDereferenceable ? MachineMemOperand::MODereferenceable
+                            : MachineMemOperand::MONone;
+
+      SDValue NewLoad, NewStore;
+      if (UseVectorTy || !DoIntegerTruncate) {
+        NewLoad =
+            DAG.getLoad(JointMemOpVT, LoadDL, FirstLoad->getChain(),
+                        FirstLoad->getBasePtr(), FirstLoad->getPointerInfo(),
+                        FirstLoadAlign, MMOFlags);
+        NewStore = DAG.getStore(
+            NewStoreChain, StoreDL, NewLoad, FirstInChain->getBasePtr(),
+            FirstInChain->getPointerInfo(), FirstStoreAlign);
+      } else { // This must be the truncstore/extload case
+        EVT ExtendedTy =
+            TLI.getTypeToTransformTo(*DAG.getContext(), JointMemOpVT);
+        NewLoad = DAG.getExtLoad(ISD::EXTLOAD, LoadDL, ExtendedTy,
+                                 FirstLoad->getChain(), FirstLoad->getBasePtr(),
+                                 FirstLoad->getPointerInfo(), JointMemOpVT,
+                                 FirstLoadAlign, MMOFlags);
+        NewStore = DAG.getTruncStore(NewStoreChain, StoreDL, NewLoad,
+                                     FirstInChain->getBasePtr(),
+                                     FirstInChain->getPointerInfo(),
+                                     JointMemOpVT, FirstInChain->getAlignment(),
+                                     FirstInChain->getMemOperand()->getFlags());
+      }
+
+      // Transfer chain users from old loads to the new load.
+      for (unsigned i = 0; i < NumElem; ++i) {
+        LoadSDNode *Ld = cast<LoadSDNode>(LoadNodes[i].MemNode);
+        DAG.ReplaceAllUsesOfValueWith(SDValue(Ld, 1),
+                                      SDValue(NewLoad.getNode(), 1));
+      }
+
+      // Replace all of the stores with the new store. Recursively remove the
+      // corresponding value if it's no longer used.
+      for (unsigned i = 0; i < NumElem; ++i) {
+        SDValue Val = StoreNodes[i].MemNode->getOperand(1);
+        CombineTo(StoreNodes[i].MemNode, NewStore);
+        if (Val.getNode()->use_empty())
+          recursivelyDeleteUnusedNodes(Val.getNode());
+      }
+
+      RV = true;
+      StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + NumElem);
+      LoadNodes.erase(LoadNodes.begin(), LoadNodes.begin() + NumElem);
+      NumConsecutiveStores -= NumElem;
+    }
   }
   return RV;
 }
@@ -13713,13 +14507,14 @@ SDValue DAGCombiner::visitSTORE(SDNode *N) {

   // Try to infer better alignment information than the store already has.
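The replacement built above is one wide load feeding one wide store; chain users of the dead loads are redirected to the new load and each old store is combined away. That the data movement itself is sound can be checked byte-wise in plain C++ (a sketch of the idea, not the DAG transform):

#include <cstdint>
#include <cstdio>
#include <cstring>

// Four adjacent i16 copies are observably identical to one i64 copy of
// the same bytes, on either endianness, because the bytes move as a
// block. This mirrors replacing the per-element load/store chain with a
// single JointMemOpVT load/store.
int main() {
  uint16_t Src[4] = {1, 2, 3, 4};
  uint16_t A[4], B[4];

  for (int i = 0; i < 4; ++i) // the original four loads and stores
    A[i] = Src[i];

  uint64_t Wide;                       // the merged load ...
  std::memcpy(&Wide, Src, sizeof(Wide));
  std::memcpy(B, &Wide, sizeof(Wide)); // ... and the merged store

  std::printf("equal: %d\n", std::memcmp(A, B, sizeof(A)) == 0);
}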
   if (OptLevel != CodeGenOpt::None && ST->isUnindexed()) {
     if (unsigned Align = DAG.InferPtrAlignment(Ptr)) {
-      if (Align > ST->getAlignment()) {
+      if (Align > ST->getAlignment() && ST->getSrcValueOffset() % Align == 0) {
         SDValue NewStore =
             DAG.getTruncStore(Chain, SDLoc(N), Value, Ptr, ST->getPointerInfo(),
                               ST->getMemoryVT(), Align,
                               ST->getMemOperand()->getFlags(), ST->getAAInfo());
-        if (NewStore.getNode() != N)
-          return CombineTo(ST, NewStore, true);
+        // NewStore will always be N as we are only refining the alignment.
+        assert(NewStore.getNode() == N);
+        (void)NewStore;
       }
     }
   }
@@ -13783,30 +14578,30 @@ SDValue DAGCombiner::visitSTORE(SDNode *N) {
     }
   }

-  // Deal with elidable overlapping chained stores.
-  if (StoreSDNode *ST1 = dyn_cast<StoreSDNode>(Chain))
-    if (OptLevel != CodeGenOpt::None && ST->isUnindexed() &&
-        ST1->isUnindexed() && !ST1->isVolatile() && ST1->hasOneUse() &&
-        !ST1->getBasePtr().isUndef() && !ST->isVolatile()) {
-      BaseIndexOffset STBasePtr = BaseIndexOffset::match(ST->getBasePtr(), DAG);
-      BaseIndexOffset ST1BasePtr =
-          BaseIndexOffset::match(ST1->getBasePtr(), DAG);
-      unsigned STBytes = ST->getMemoryVT().getStoreSize();
-      unsigned ST1Bytes = ST1->getMemoryVT().getStoreSize();
-      int64_t PtrDiff;
-      // If this is a store who's preceeding store to a subset of the same
-      // memory and no one other node is chained to that store we can
-      // effectively drop the store. Do not remove stores to undef as they may
-      // be used as data sinks.
-
-      if (((ST->getBasePtr() == ST1->getBasePtr()) &&
-           (ST->getValue() == ST1->getValue())) ||
-          (STBasePtr.equalBaseIndex(ST1BasePtr, DAG, PtrDiff) &&
-           (0 <= PtrDiff) && (PtrDiff + ST1Bytes <= STBytes))) {
+  if (StoreSDNode *ST1 = dyn_cast<StoreSDNode>(Chain)) {
+    if (ST->isUnindexed() && !ST->isVolatile() && ST1->isUnindexed() &&
+        !ST1->isVolatile() && ST1->getBasePtr() == Ptr &&
+        ST->getMemoryVT() == ST1->getMemoryVT()) {
+      // If this is a store followed by a store with the same value to the same
+      // location, then the store is dead/noop.
+      if (ST1->getValue() == Value) {
+        // The store is dead, remove it.
+        return Chain;
+      }
+
+      // If this store's preceding store writes to the same location and no
+      // other node is chained to that store, we can effectively drop the
+      // preceding store. Do not remove stores to undef as they may be used
+      // as data sinks.
+      if (OptLevel != CodeGenOpt::None && ST1->hasOneUse() &&
+          !ST1->getBasePtr().isUndef()) {
+        // ST1 is fully overwritten and can be elided. Combine with its chain
+        // value.
         CombineTo(ST1, ST1->getChain());
-        return SDValue(N, 0);
+        return SDValue();
       }
     }
+  }

   // If this is an FP_ROUND or TRUNC followed by a store, fold this into a
   // truncating store. We can do this even if this is already a truncstore.
@@ -14201,6 +14996,10 @@ SDValue DAGCombiner::visitEXTRACT_VECTOR_ELT(SDNode *N) {
   SDValue EltNo = N->getOperand(1);
   ConstantSDNode *ConstEltNo = dyn_cast<ConstantSDNode>(EltNo);

+  // extract_vector_elt of out-of-bounds element -> UNDEF
+  if (ConstEltNo && ConstEltNo->getAPIntValue().uge(VT.getVectorNumElements()))
+    return DAG.getUNDEF(NVT);
+
   // extract_vector_elt (build_vector x, y), 1 -> y
   if (ConstEltNo &&
       InVec.getOpcode() == ISD::BUILD_VECTOR &&
@@ -14286,6 +15085,23 @@ SDValue DAGCombiner::visitEXTRACT_VECTOR_ELT(SDNode *N) {
     }
   }

+  // If only EXTRACT_VECTOR_ELT nodes use the source vector we can
+  // simplify it based on the (valid) extraction indices.
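The code that follows gathers the in-range constant indices of every extract into a demanded-elements mask and passes it to SimplifyDemandedVectorElts, so lanes that nobody reads may become undef. A small model of the mask construction, with std::bitset playing the role of APInt:

#include <bitset>
#include <cstdio>
#include <vector>

int main() {
  // Indices extracted from a hypothetical v8i32 source; every user of the
  // vector is an extract with a constant, in-range index.
  const unsigned NumElts = 8;
  std::vector<unsigned> ExtractIndices = {0, 2, 2, 5};

  // Build the mask the way the combine builds its APInt: one bit per lane
  // that some extract actually reads.
  std::bitset<32> Demanded; // APInt::getNullValue(NumElts) analogue
  for (unsigned Idx : ExtractIndices)
    if (Idx < NumElts)
      Demanded.set(Idx);

  // Lanes outside the mask can be simplified away without changing any
  // observable value.
  std::printf("demanded mask: %s\n",
              Demanded.to_string().substr(32 - NumElts).c_str());
}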
+  if (llvm::all_of(InVec->uses(), [&](SDNode *Use) {
+        return Use->getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
+               Use->getOperand(0) == InVec &&
+               isa<ConstantSDNode>(Use->getOperand(1));
+      })) {
+    APInt DemandedElts = APInt::getNullValue(VT.getVectorNumElements());
+    for (SDNode *Use : InVec->uses()) {
+      auto *CstElt = cast<ConstantSDNode>(Use->getOperand(1));
+      if (CstElt->getAPIntValue().ult(VT.getVectorNumElements()))
+        DemandedElts.setBit(CstElt->getZExtValue());
+    }
+    if (SimplifyDemandedVectorElts(InVec, DemandedElts, true))
+      return SDValue(N, 0);
+  }
+
   bool BCNumEltsChanged = false;
   EVT ExtVT = VT.getVectorElementType();
   EVT LVT = ExtVT;
@@ -14492,7 +15308,10 @@ SDValue DAGCombiner::reduceBuildVecExtToExtBuildVec(SDNode *N) {
   assert(VecVT.getSizeInBits() == VT.getSizeInBits() && "Invalid vector size");

   // Check if the new vector type is legal.
-  if (!isTypeLegal(VecVT)) return SDValue();
+  if (!isTypeLegal(VecVT) ||
+      (!TLI.isOperationLegal(ISD::BUILD_VECTOR, VecVT) &&
+       TLI.isOperationLegal(ISD::BUILD_VECTOR, VT)))
+    return SDValue();

   // Make the new BUILD_VECTOR.
   SDValue BV = DAG.getBuildVector(VecVT, DL, Ops);
@@ -14739,12 +15558,16 @@ SDValue DAGCombiner::reduceBuildVecToShuffle(SDNode *N) {
     }

     // Not an undef or zero. If the input is something other than an
-    // EXTRACT_VECTOR_ELT with a constant index, bail out.
+    // EXTRACT_VECTOR_ELT with an in-range constant index, bail out.
     if (Op.getOpcode() != ISD::EXTRACT_VECTOR_ELT ||
         !isa<ConstantSDNode>(Op.getOperand(1)))
       return SDValue();
     SDValue ExtractedFromVec = Op.getOperand(0);

+    APInt ExtractIdx = cast<ConstantSDNode>(Op.getOperand(1))->getAPIntValue();
+    if (ExtractIdx.uge(ExtractedFromVec.getValueType().getVectorNumElements()))
+      return SDValue();
+
     // All inputs must have the same element type as the output.
     if (VT.getVectorElementType() !=
         ExtractedFromVec.getValueType().getVectorElementType())
@@ -14900,6 +15723,54 @@ SDValue DAGCombiner::reduceBuildVecToShuffle(SDNode *N) {
   return Shuffles[0];
 }

+// Try to turn a build vector of zero extends of extract vector elts into a
+// vector zero extend and possibly an extract subvector.
+// TODO: Support sign extend or any extend?
+// TODO: Allow undef elements?
+// TODO: Don't require the extracts to start at element 0.
+SDValue DAGCombiner::convertBuildVecZextToZext(SDNode *N) {
+  if (LegalOperations)
+    return SDValue();
+
+  EVT VT = N->getValueType(0);
+
+  SDValue Op0 = N->getOperand(0);
+  auto checkElem = [&](SDValue Op) -> int64_t {
+    if (Op.getOpcode() == ISD::ZERO_EXTEND &&
+        Op.getOperand(0).getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
+        Op0.getOperand(0).getOperand(0) == Op.getOperand(0).getOperand(0))
+      if (auto *C = dyn_cast<ConstantSDNode>(Op.getOperand(0).getOperand(1)))
+        return C->getZExtValue();
+    return -1;
+  };
+
+  // Make sure the first element matches
+  // (zext (extract_vector_elt X, C))
+  int64_t Offset = checkElem(Op0);
+  if (Offset < 0)
+    return SDValue();
+
+  unsigned NumElems = N->getNumOperands();
+  SDValue In = Op0.getOperand(0).getOperand(0);
+  EVT InSVT = In.getValueType().getScalarType();
+  EVT InVT = EVT::getVectorVT(*DAG.getContext(), InSVT, NumElems);
+
+  // Don't create an illegal input type after type legalization.
+  if (LegalTypes && !TLI.isTypeLegal(InVT))
+    return SDValue();
+
+  // Ensure all the elements come from the same vector and are adjacent.
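convertBuildVecZextToZext only fires when every operand is zext(extract_vector_elt(X, Offset + i)) of one source X at consecutive indices; checkElem yields the extract index, or -1 on a mismatch. A scalar model of the adjacency test that the loop below performs (the index values are illustrative):

#include <cstdio>
#include <vector>

int main() {
  // One entry per build_vector operand: the constant index it extracts,
  // or -1 when the operand is not zext(extract(X, C)).
  std::vector<long> ExtractIdx = {2, 3, 4, 5};
  long Offset = ExtractIdx[0];

  bool AllAdjacent = Offset >= 0;
  for (unsigned i = 1; i < ExtractIdx.size() && AllAdjacent; ++i)
    AllAdjacent = (Offset + i) == ExtractIdx[i];

  // When every lane lines up, the whole node can be rebuilt as
  // zext(extract_subvector(X, Offset)).
  std::printf("fold to vector zext: %d\n", AllAdjacent);
}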
+ for (unsigned i = 1; i != NumElems; ++i) { + if ((Offset + i) != checkElem(N->getOperand(i))) + return SDValue(); + } + + SDLoc DL(N); + In = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, InVT, In, + Op0.getOperand(0).getOperand(1)); + return DAG.getNode(ISD::ZERO_EXTEND, DL, VT, In); +} + SDValue DAGCombiner::visitBUILD_VECTOR(SDNode *N) { EVT VT = N->getValueType(0); @@ -14907,6 +15778,32 @@ SDValue DAGCombiner::visitBUILD_VECTOR(SDNode *N) { if (ISD::allOperandsUndef(N)) return DAG.getUNDEF(VT); + // If this is a splat of a bitcast from another vector, change to a + // concat_vector. + // For example: + // (build_vector (i64 (bitcast (v2i32 X))), (i64 (bitcast (v2i32 X)))) -> + // (v2i64 (bitcast (concat_vectors (v2i32 X), (v2i32 X)))) + // + // If X is a build_vector itself, the concat can become a larger build_vector. + // TODO: Maybe this is useful for non-splat too? + if (!LegalOperations) { + if (SDValue Splat = cast<BuildVectorSDNode>(N)->getSplatValue()) { + Splat = peekThroughBitcast(Splat); + EVT SrcVT = Splat.getValueType(); + if (SrcVT.isVector()) { + unsigned NumElts = N->getNumOperands() * SrcVT.getVectorNumElements(); + EVT NewVT = EVT::getVectorVT(*DAG.getContext(), + SrcVT.getVectorElementType(), NumElts); + if (!LegalTypes || TLI.isTypeLegal(NewVT)) { + SmallVector<SDValue, 8> Ops(N->getNumOperands(), Splat); + SDValue Concat = DAG.getNode(ISD::CONCAT_VECTORS, SDLoc(N), + NewVT, Ops); + return DAG.getBitcast(VT, Concat); + } + } + } + } + // Check if we can express BUILD VECTOR via subvector extract. if (!LegalTypes && (N->getNumOperands() > 1)) { SDValue Op0 = N->getOperand(0); @@ -14936,6 +15833,9 @@ SDValue DAGCombiner::visitBUILD_VECTOR(SDNode *N) { Op0.getOperand(0), Op0.getOperand(1)); } + if (SDValue V = convertBuildVecZextToZext(N)) + return V; + if (SDValue V = reduceBuildVecExtToExtBuildVec(N)) return V; @@ -15125,6 +16025,10 @@ SDValue DAGCombiner::visitCONCAT_VECTORS(SDNode *N) { if (!SclTy.isFloatingPoint() && !SclTy.isInteger()) return SDValue(); + // Bail out if the vector size is not a multiple of the scalar size. + if (VT.getSizeInBits() % SclTy.getSizeInBits()) + return SDValue(); + unsigned VNTNumElms = VT.getSizeInBits() / SclTy.getSizeInBits(); if (VNTNumElms < 2) return SDValue(); @@ -15403,13 +16307,22 @@ SDValue DAGCombiner::visitEXTRACT_SUBVECTOR(SDNode* N) { // Only do this if we won't split any elements. if (ExtractSize % EltSize == 0) { unsigned NumElems = ExtractSize / EltSize; - EVT ExtractVT = EVT::getVectorVT(*DAG.getContext(), - InVT.getVectorElementType(), NumElems); - if ((!LegalOperations || - TLI.isOperationLegal(ISD::BUILD_VECTOR, ExtractVT)) && + EVT EltVT = InVT.getVectorElementType(); + EVT ExtractVT = NumElems == 1 ? EltVT : + EVT::getVectorVT(*DAG.getContext(), EltVT, NumElems); + if ((Level < AfterLegalizeDAG || + (NumElems == 1 || + TLI.isOperationLegal(ISD::BUILD_VECTOR, ExtractVT))) && (!LegalTypes || TLI.isTypeLegal(ExtractVT))) { unsigned IdxVal = (Idx->getZExtValue() * NVT.getScalarSizeInBits()) / EltSize; + if (NumElems == 1) { + SDValue Src = V->getOperand(IdxVal); + if (EltVT != Src.getValueType()) + Src = DAG.getNode(ISD::TRUNCATE, SDLoc(N), InVT, Src); + + return DAG.getBitcast(NVT, Src); + } // Extract the pieces from the original build_vector. 
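The splat fold near the top of this hunk is a pure re-bracketing of bytes: a build_vector whose lanes are each the i64 bitcast of the same v2i32 X holds exactly the bytes of concat_vectors(X, X). A byte-level check of that equivalence:

#include <cstdint>
#include <cstdio>
#include <cstring>

int main() {
  uint32_t X[2] = {0x11111111, 0x22222222}; // the v2i32 source

  uint64_t Lane;                    // i64 (bitcast X)
  std::memcpy(&Lane, X, sizeof(Lane));
  uint64_t Splat[2] = {Lane, Lane}; // build_vector of the splat

  uint32_t Concat[4] = {X[0], X[1], X[0], X[1]}; // concat_vectors(X, X)

  // Identical bytes regardless of endianness, since each lane carries
  // X's bytes in memory order.
  std::printf("same bytes: %d\n",
              std::memcmp(Splat, Concat, sizeof(Splat)) == 0);
}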
SDValue BuildVec = DAG.getBuildVector(ExtractVT, SDLoc(N), @@ -15451,122 +16364,8 @@ SDValue DAGCombiner::visitEXTRACT_SUBVECTOR(SDNode* N) { if (SDValue NarrowBOp = narrowExtractedVectorBinOp(N, DAG)) return NarrowBOp; - return SDValue(); -} - -static SDValue simplifyShuffleOperandRecursively(SmallBitVector &UsedElements, - SDValue V, SelectionDAG &DAG) { - SDLoc DL(V); - EVT VT = V.getValueType(); - - switch (V.getOpcode()) { - default: - return V; - - case ISD::CONCAT_VECTORS: { - EVT OpVT = V->getOperand(0).getValueType(); - int OpSize = OpVT.getVectorNumElements(); - SmallBitVector OpUsedElements(OpSize, false); - bool FoundSimplification = false; - SmallVector<SDValue, 4> NewOps; - NewOps.reserve(V->getNumOperands()); - for (int i = 0, NumOps = V->getNumOperands(); i < NumOps; ++i) { - SDValue Op = V->getOperand(i); - bool OpUsed = false; - for (int j = 0; j < OpSize; ++j) - if (UsedElements[i * OpSize + j]) { - OpUsedElements[j] = true; - OpUsed = true; - } - NewOps.push_back( - OpUsed ? simplifyShuffleOperandRecursively(OpUsedElements, Op, DAG) - : DAG.getUNDEF(OpVT)); - FoundSimplification |= Op == NewOps.back(); - OpUsedElements.reset(); - } - if (FoundSimplification) - V = DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, NewOps); - return V; - } - - case ISD::INSERT_SUBVECTOR: { - SDValue BaseV = V->getOperand(0); - SDValue SubV = V->getOperand(1); - auto *IdxN = dyn_cast<ConstantSDNode>(V->getOperand(2)); - if (!IdxN) - return V; - - int SubSize = SubV.getValueType().getVectorNumElements(); - int Idx = IdxN->getZExtValue(); - bool SubVectorUsed = false; - SmallBitVector SubUsedElements(SubSize, false); - for (int i = 0; i < SubSize; ++i) - if (UsedElements[i + Idx]) { - SubVectorUsed = true; - SubUsedElements[i] = true; - UsedElements[i + Idx] = false; - } - - // Now recurse on both the base and sub vectors. - SDValue SimplifiedSubV = - SubVectorUsed - ? simplifyShuffleOperandRecursively(SubUsedElements, SubV, DAG) - : DAG.getUNDEF(SubV.getValueType()); - SDValue SimplifiedBaseV = simplifyShuffleOperandRecursively(UsedElements, BaseV, DAG); - if (SimplifiedSubV != SubV || SimplifiedBaseV != BaseV) - V = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, VT, - SimplifiedBaseV, SimplifiedSubV, V->getOperand(2)); - return V; - } - } -} - -static SDValue simplifyShuffleOperands(ShuffleVectorSDNode *SVN, SDValue N0, - SDValue N1, SelectionDAG &DAG) { - EVT VT = SVN->getValueType(0); - int NumElts = VT.getVectorNumElements(); - SmallBitVector N0UsedElements(NumElts, false), N1UsedElements(NumElts, false); - for (int M : SVN->getMask()) - if (M >= 0 && M < NumElts) - N0UsedElements[M] = true; - else if (M >= NumElts) - N1UsedElements[M - NumElts] = true; - - SDValue S0 = simplifyShuffleOperandRecursively(N0UsedElements, N0, DAG); - SDValue S1 = simplifyShuffleOperandRecursively(N1UsedElements, N1, DAG); - if (S0 == N0 && S1 == N1) - return SDValue(); - - return DAG.getVectorShuffle(VT, SDLoc(SVN), S0, S1, SVN->getMask()); -} - -static SDValue simplifyShuffleMask(ShuffleVectorSDNode *SVN, SDValue N0, - SDValue N1, SelectionDAG &DAG) { - auto isUndefElt = [](SDValue V, int Idx) { - // TODO - handle more cases as required. 
- if (V.getOpcode() == ISD::BUILD_VECTOR) - return V.getOperand(Idx).isUndef(); - if (V.getOpcode() == ISD::SCALAR_TO_VECTOR) - return (Idx != 0) || V.getOperand(0).isUndef(); - return false; - }; - - EVT VT = SVN->getValueType(0); - unsigned NumElts = VT.getVectorNumElements(); - - bool Changed = false; - SmallVector<int, 8> NewMask; - for (unsigned i = 0; i != NumElts; ++i) { - int Idx = SVN->getMaskElt(i); - if ((0 <= Idx && Idx < (int)NumElts && isUndefElt(N0, Idx)) || - ((int)NumElts < Idx && isUndefElt(N1, Idx - NumElts))) { - Changed = true; - Idx = -1; - } - NewMask.push_back(Idx); - } - if (Changed) - return DAG.getVectorShuffle(VT, SDLoc(SVN), N0, N1, NewMask); + if (SimplifyDemandedVectorElts(SDValue(N, 0))) + return SDValue(N, 0); return SDValue(); } @@ -16013,10 +16812,6 @@ SDValue DAGCombiner::visitVECTOR_SHUFFLE(SDNode *N) { return DAG.getVectorShuffle(VT, SDLoc(N), N0, N1, NewMask); } - // Simplify shuffle mask if a referenced element is UNDEF. - if (SDValue V = simplifyShuffleMask(SVN, N0, N1, DAG)) - return V; - if (SDValue InsElt = replaceShuffleOfInsert(SVN, DAG)) return InsElt; @@ -16077,11 +16872,9 @@ SDValue DAGCombiner::visitVECTOR_SHUFFLE(SDNode *N) { } } - // There are various patterns used to build up a vector from smaller vectors, - // subvectors, or elements. Scan chains of these and replace unused insertions - // or components with undef. - if (SDValue S = simplifyShuffleOperands(SVN, N0, N1, DAG)) - return S; + // Simplify source operands based on shuffle mask. + if (SimplifyDemandedVectorElts(SDValue(N, 0))) + return SDValue(N, 0); // Match shuffles that can be converted to any_vector_extend_in_reg. if (SDValue V = combineShuffleToVectorExtend(SVN, DAG, TLI, LegalOperations, LegalTypes)) @@ -16394,7 +17187,9 @@ SDValue DAGCombiner::visitINSERT_SUBVECTOR(SDNode *N) { N1.getOperand(0).getOpcode() == ISD::EXTRACT_SUBVECTOR && N1.getOperand(0).getOperand(1) == N2 && N1.getOperand(0).getOperand(0).getValueType().getVectorNumElements() == - VT.getVectorNumElements()) { + VT.getVectorNumElements() && + N1.getOperand(0).getOperand(0).getValueType().getSizeInBits() == + VT.getSizeInBits()) { return DAG.getBitcast(VT, N1.getOperand(0).getOperand(0)); } @@ -16405,10 +17200,11 @@ SDValue DAGCombiner::visitINSERT_SUBVECTOR(SDNode *N) { if (N0.getOpcode() == ISD::BITCAST && N1.getOpcode() == ISD::BITCAST) { SDValue CN0 = N0.getOperand(0); SDValue CN1 = N1.getOperand(0); - if (CN0.getValueType().getVectorElementType() == - CN1.getValueType().getVectorElementType() && - CN0.getValueType().getVectorNumElements() == - VT.getVectorNumElements()) { + EVT CN0VT = CN0.getValueType(); + EVT CN1VT = CN1.getValueType(); + if (CN0VT.isVector() && CN1VT.isVector() && + CN0VT.getVectorElementType() == CN1VT.getVectorElementType() && + CN0VT.getVectorNumElements() == VT.getVectorNumElements()) { SDValue NewINSERT = DAG.getNode(ISD::INSERT_SUBVECTOR, SDLoc(N), CN0.getValueType(), CN0, CN1, N2); return DAG.getBitcast(VT, NewINSERT); @@ -16663,14 +17459,14 @@ bool DAGCombiner::SimplifySelectOps(SDNode *TheSelect, SDValue LHS, const ConstantFPSDNode *Zero = nullptr; if (TheSelect->getOpcode() == ISD::SELECT_CC) { - CC = dyn_cast<CondCodeSDNode>(TheSelect->getOperand(4))->get(); + CC = cast<CondCodeSDNode>(TheSelect->getOperand(4))->get(); CmpLHS = TheSelect->getOperand(0); Zero = isConstOrConstSplatFP(TheSelect->getOperand(1)); } else { // SELECT or VSELECT SDValue Cmp = TheSelect->getOperand(0); if (Cmp.getOpcode() == ISD::SETCC) { - CC = dyn_cast<CondCodeSDNode>(Cmp.getOperand(2))->get(); 
+ CC = cast<CondCodeSDNode>(Cmp.getOperand(2))->get(); CmpLHS = Cmp.getOperand(0); Zero = isConstOrConstSplatFP(Cmp.getOperand(1)); } @@ -16888,24 +17684,6 @@ SDValue DAGCombiner::SimplifySelectCC(const SDLoc &DL, SDValue N0, SDValue N1, return !SCCC->isNullValue() ? N2 : N3; } - // Check to see if we can simplify the select into an fabs node - if (ConstantFPSDNode *CFP = dyn_cast<ConstantFPSDNode>(N1)) { - // Allow either -0.0 or 0.0 - if (CFP->isZero()) { - // select (setg[te] X, +/-0.0), X, fneg(X) -> fabs - if ((CC == ISD::SETGE || CC == ISD::SETGT) && - N0 == N2 && N3.getOpcode() == ISD::FNEG && - N2 == N3.getOperand(0)) - return DAG.getNode(ISD::FABS, DL, VT, N0); - - // select (setl[te] X, +/-0.0), fneg(X), X -> fabs - if ((CC == ISD::SETLT || CC == ISD::SETLE) && - N0 == N3 && N2.getOpcode() == ISD::FNEG && - N2.getOperand(0) == N3) - return DAG.getNode(ISD::FABS, DL, VT, N3); - } - } - // Turn "(a cond b) ? 1.0f : 2.0f" into "load (tmp + ((a cond b) ? 0 : 4)" // where "tmp" is a constant pool entry containing an array with 1.0 and 2.0 // in it. This is a win when the constant is not otherwise available because @@ -17383,19 +18161,34 @@ SDValue DAGCombiner::buildSqrtEstimateImpl(SDValue Op, SDNodeFlags Flags, : buildSqrtNRTwoConst(Op, Est, Iterations, Flags, Reciprocal); if (!Reciprocal) { - // Unfortunately, Est is now NaN if the input was exactly 0.0. - // Select out this case and force the answer to 0.0. + // The estimate is now completely wrong if the input was exactly 0.0 or + // possibly a denormal. Force the answer to 0.0 for those cases. EVT VT = Op.getValueType(); SDLoc DL(Op); - - SDValue FPZero = DAG.getConstantFP(0.0, DL, VT); EVT CCVT = getSetCCResultType(VT); - SDValue ZeroCmp = DAG.getSetCC(DL, CCVT, Op, FPZero, ISD::SETEQ); - AddToWorklist(ZeroCmp.getNode()); - - Est = DAG.getNode(VT.isVector() ? ISD::VSELECT : ISD::SELECT, DL, VT, - ZeroCmp, FPZero, Est); - AddToWorklist(Est.getNode()); + ISD::NodeType SelOpcode = VT.isVector() ? ISD::VSELECT : ISD::SELECT; + const Function &F = DAG.getMachineFunction().getFunction(); + Attribute Denorms = F.getFnAttribute("denormal-fp-math"); + if (Denorms.getValueAsString().equals("ieee")) { + // fabs(X) < SmallestNormal ? 0.0 : Est + const fltSemantics &FltSem = DAG.EVTToAPFloatSemantics(VT); + APFloat SmallestNorm = APFloat::getSmallestNormalized(FltSem); + SDValue NormC = DAG.getConstantFP(SmallestNorm, DL, VT); + SDValue FPZero = DAG.getConstantFP(0.0, DL, VT); + SDValue Fabs = DAG.getNode(ISD::FABS, DL, VT, Op); + SDValue IsDenorm = DAG.getSetCC(DL, CCVT, Fabs, NormC, ISD::SETLT); + Est = DAG.getNode(SelOpcode, DL, VT, IsDenorm, FPZero, Est); + AddToWorklist(Fabs.getNode()); + AddToWorklist(IsDenorm.getNode()); + AddToWorklist(Est.getNode()); + } else { + // X == 0.0 ? 0.0 : Est + SDValue FPZero = DAG.getConstantFP(0.0, DL, VT); + SDValue IsZero = DAG.getSetCC(DL, CCVT, Op, FPZero, ISD::SETEQ); + Est = DAG.getNode(SelOpcode, DL, VT, IsZero, FPZero, Est); + AddToWorklist(IsZero.getNode()); + AddToWorklist(Est.getNode()); + } } } return Est; @@ -17433,44 +18226,46 @@ bool DAGCombiner::isAlias(LSBaseSDNode *Op0, LSBaseSDNode *Op1) const { unsigned NumBytes1 = Op1->getMemoryVT().getStoreSize(); // Check for BaseIndexOffset matching. 
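The new guard above selects 0.0 whenever fabs(x) is below the smallest normalized value, not just when x == 0.0, since under "denormal-fp-math"="ieee" a hardware reciprocal-square-root estimate is unreliable for denormal inputs. A scalar sketch of the select (guardedSqrtEst is an invented name; x = 0.0 shows the 0 * inf = NaN failure the select papers over):

#include <cmath>
#include <cstdio>
#include <limits>

// Model of "fabs(X) < SmallestNormal ? 0.0 : Est".
static float guardedSqrtEst(float X, float RawEstimate) {
  const float SmallestNorm = std::numeric_limits<float>::min();
  return std::fabs(X) < SmallestNorm ? 0.0f : RawEstimate;
}

int main() {
  float X = 0.0f;
  // Stands in for Est = X * rsqrt(X): 0 * inf gives NaN.
  float Raw = X * (1.0f / std::sqrt(X));
  std::printf("raw=%g guarded=%g\n", Raw, guardedSqrtEst(X, Raw));
}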
-  BaseIndexOffset BasePtr0 = BaseIndexOffset::match(Op0->getBasePtr(), DAG);
-  BaseIndexOffset BasePtr1 = BaseIndexOffset::match(Op1->getBasePtr(), DAG);
+  BaseIndexOffset BasePtr0 = BaseIndexOffset::match(Op0, DAG);
+  BaseIndexOffset BasePtr1 = BaseIndexOffset::match(Op1, DAG);
   int64_t PtrDiff;
-  if (BasePtr0.equalBaseIndex(BasePtr1, DAG, PtrDiff))
-    return !((NumBytes0 <= PtrDiff) || (PtrDiff + NumBytes1 <= 0));
-
-  // If both BasePtr0 and BasePtr1 are FrameIndexes, we will not be
-  // able to calculate their relative offset if at least one arises
-  // from an alloca. However, these allocas cannot overlap and we
-  // can infer there is no alias.
-  if (auto *A = dyn_cast<FrameIndexSDNode>(BasePtr0.getBase()))
-    if (auto *B = dyn_cast<FrameIndexSDNode>(BasePtr1.getBase())) {
-      MachineFrameInfo &MFI = DAG.getMachineFunction().getFrameInfo();
-      // If the base are the same frame index but the we couldn't find a
-      // constant offset, (indices are different) be conservative.
-      if (A != B && (!MFI.isFixedObjectIndex(A->getIndex()) ||
-                     !MFI.isFixedObjectIndex(B->getIndex())))
-        return false;
-    }
-
-  bool IsFI0 = isa<FrameIndexSDNode>(BasePtr0.getBase());
-  bool IsFI1 = isa<FrameIndexSDNode>(BasePtr1.getBase());
-  bool IsGV0 = isa<GlobalAddressSDNode>(BasePtr0.getBase());
-  bool IsGV1 = isa<GlobalAddressSDNode>(BasePtr1.getBase());
-  bool IsCV0 = isa<ConstantPoolSDNode>(BasePtr0.getBase());
-  bool IsCV1 = isa<ConstantPoolSDNode>(BasePtr1.getBase());
+  if (BasePtr0.getBase().getNode() && BasePtr1.getBase().getNode()) {
+    if (BasePtr0.equalBaseIndex(BasePtr1, DAG, PtrDiff))
+      return !((NumBytes0 <= PtrDiff) || (PtrDiff + NumBytes1 <= 0));
+
+    // If both BasePtr0 and BasePtr1 are FrameIndexes, we will not be
+    // able to calculate their relative offset if at least one arises
+    // from an alloca. However, these allocas cannot overlap and we
+    // can infer there is no alias.
+    if (auto *A = dyn_cast<FrameIndexSDNode>(BasePtr0.getBase()))
+      if (auto *B = dyn_cast<FrameIndexSDNode>(BasePtr1.getBase())) {
+        MachineFrameInfo &MFI = DAG.getMachineFunction().getFrameInfo();
+        // If the bases are the same frame index but we couldn't find a
+        // constant offset (indices are different), be conservative.
+        if (A != B && (!MFI.isFixedObjectIndex(A->getIndex()) ||
+                       !MFI.isFixedObjectIndex(B->getIndex())))
+          return false;
+      }

-  // If of mismatched base types or checkable indices we can check
-  // they do not alias.
-  if ((BasePtr0.getIndex() == BasePtr1.getIndex() || (IsFI0 != IsFI1) ||
-       (IsGV0 != IsGV1) || (IsCV0 != IsCV1)) &&
-      (IsFI0 || IsGV0 || IsCV0) && (IsFI1 || IsGV1 || IsCV1))
-    return false;
+    bool IsFI0 = isa<FrameIndexSDNode>(BasePtr0.getBase());
+    bool IsFI1 = isa<FrameIndexSDNode>(BasePtr1.getBase());
+    bool IsGV0 = isa<GlobalAddressSDNode>(BasePtr0.getBase());
+    bool IsGV1 = isa<GlobalAddressSDNode>(BasePtr1.getBase());
+    bool IsCV0 = isa<ConstantPoolSDNode>(BasePtr0.getBase());
+    bool IsCV1 = isa<ConstantPoolSDNode>(BasePtr1.getBase());
+
+    // With mismatched base types or checkable indices, we can check
+    // that they do not alias.
+    if ((BasePtr0.getIndex() == BasePtr1.getIndex() || (IsFI0 != IsFI1) ||
+         (IsGV0 != IsGV1) || (IsCV0 != IsCV1)) &&
+        (IsFI0 || IsGV0 || IsCV0) && (IsFI1 || IsGV1 || IsCV1))
+      return false;
+  }

-  // If we know required SrcValue1 and SrcValue2 have relatively large alignment
-  // compared to the size and offset of the access, we may be able to prove they
-  // do not alias. This check is conservative for now to catch cases created by
-  // splitting vector types.
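Once two accesses are known to share a base and index, aliasing reduces to an overlap test between the byte intervals [0, NumBytes0) and [PtrDiff, PtrDiff + NumBytes1), which is exactly what the equalBaseIndex path above returns. The test in isolation:

#include <cstdint>
#include <cstdio>

// Accesses alias iff their byte ranges overlap once expressed relative
// to the first access.
static bool mayAlias(int64_t PtrDiff, int64_t NumBytes0, int64_t NumBytes1) {
  return !((NumBytes0 <= PtrDiff) || (PtrDiff + NumBytes1 <= 0));
}

int main() {
  std::printf("%d\n", mayAlias(4, 4, 4));  // adjacent above: no alias -> 0
  std::printf("%d\n", mayAlias(2, 4, 4));  // overlapping -> 1
  std::printf("%d\n", mayAlias(-4, 4, 4)); // adjacent below: no alias -> 0
}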
+  // If we know SrcValue1 and SrcValue2 have relatively large alignment
+  // compared to the size and offset of the access, we may be able
+  // to prove they do not alias. This check is conservative for now to catch
+  // cases created by splitting vector types.
   int64_t SrcValOffset0 = Op0->getSrcValueOffset();
   int64_t SrcValOffset1 = Op1->getSrcValueOffset();
   unsigned OrigAlignment0 = Op0->getOriginalAlignment();
@@ -17480,8 +18275,8 @@ bool DAGCombiner::isAlias(LSBaseSDNode *Op0, LSBaseSDNode *Op1) const {
     int64_t OffAlign0 = SrcValOffset0 % OrigAlignment0;
     int64_t OffAlign1 = SrcValOffset1 % OrigAlignment1;

-    // There is no overlap between these relatively aligned accesses of similar
-    // size. Return no alias.
+    // There is no overlap between these relatively aligned accesses of
+    // similar size. Return no alias.
     if ((OffAlign0 + NumBytes0) <= OffAlign1 ||
         (OffAlign1 + NumBytes1) <= OffAlign0)
       return false;
@@ -17644,7 +18439,7 @@ bool DAGCombiner::findBetterNeighborChains(StoreSDNode *St) {

   // This holds the base pointer, index, and the offset in bytes from the base
   // pointer.
-  BaseIndexOffset BasePtr = BaseIndexOffset::match(St->getBasePtr(), DAG);
+  BaseIndexOffset BasePtr = BaseIndexOffset::match(St, DAG);

   // We must have a base and an offset.
   if (!BasePtr.getBase().getNode())
@@ -17670,7 +18465,7 @@ bool DAGCombiner::findBetterNeighborChains(StoreSDNode *St) {
       break;

     // Find the base pointer and offset for this memory node.
-    BaseIndexOffset Ptr = BaseIndexOffset::match(Index->getBasePtr(), DAG);
+    BaseIndexOffset Ptr = BaseIndexOffset::match(Index, DAG);

     // Check that the base pointer is the same as the original one.
     if (!BasePtr.equalBaseIndex(Ptr, DAG))
@@ -17696,7 +18491,7 @@ bool DAGCombiner::findBetterNeighborChains(StoreSDNode *St) {
       Index = nullptr;
       break;
     }
-  } // end while
+  }// end while
 }

 // At this point, ChainedStores lists all of the Store nodes
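The conservative final check in isAlias reasons from alignment alone: an access that fits inside its original alignment sits at a fixed position within any aligned block, so disjoint in-block byte ranges can never overlap. A standalone sketch (the guard here simplifies the real precondition):

#include <cstdint>
#include <cstdio>

// Positions within an aligned block are Offset % Alignment; if both
// accesses fit in their alignment and those in-block ranges are disjoint,
// the accesses cannot touch the same byte.
static bool noAliasByAlignment(int64_t Off0, unsigned Align0, unsigned Bytes0,
                               int64_t Off1, unsigned Align1, unsigned Bytes1) {
  if (Align0 != Align1 || Bytes0 > Align0 || Bytes1 > Align1)
    return false; // can't reason this way; conservatively assume aliasing
  int64_t OffAlign0 = Off0 % Align0;
  int64_t OffAlign1 = Off1 % Align1;
  return (OffAlign0 + Bytes0) <= OffAlign1 || (OffAlign1 + Bytes1) <= OffAlign0;
}

int main() {
  // Two 4-byte accesses at offsets 0 and 4 of 8-byte-aligned memory can
  // never overlap; at offsets 0 and 2 they might.
  std::printf("%d\n", noAliasByAlignment(0, 8, 4, 4, 8, 4)); // 1
  std::printf("%d\n", noAliasByAlignment(0, 8, 4, 2, 8, 4)); // 0
}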