author    | Dimitry Andric <dim@FreeBSD.org> | 2020-07-26 19:36:28 +0000
committer | Dimitry Andric <dim@FreeBSD.org> | 2020-07-26 19:36:28 +0000
commit    | cfca06d7963fa0909f90483b42a6d7d194d01e08
tree      | 209fb2a2d68f8f277793fc8df46c753d31bc853b /llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
parent    | 706b4fc47bbc608932d3b491ae19a3b9cde9497b
Diffstat (limited to 'llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp')
-rw-r--r-- | llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp | 3368
1 file changed, 2132 insertions, 1236 deletions
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp index e5bc08b9280ab..f14b3dba4f318 100644 --- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp @@ -30,6 +30,7 @@ #include "llvm/ADT/Statistic.h" #include "llvm/Analysis/AliasAnalysis.h" #include "llvm/Analysis/MemoryLocation.h" +#include "llvm/Analysis/VectorUtils.h" #include "llvm/CodeGen/DAGCombine.h" #include "llvm/CodeGen/ISDOpcodes.h" #include "llvm/CodeGen/MachineFrameInfo.h" @@ -124,17 +125,29 @@ static cl::opt<unsigned> StoreMergeDependenceLimit( cl::desc("Limit the number of times for the same StoreNode and RootNode " "to bail out in store merging dependence check")); +static cl::opt<bool> EnableReduceLoadOpStoreWidth( + "combiner-reduce-load-op-store-width", cl::Hidden, cl::init(true), + cl::desc("DAG cominber enable reducing the width of load/op/store " + "sequence")); + +static cl::opt<bool> EnableShrinkLoadReplaceStoreWithStore( + "combiner-shrink-load-replace-store-with-store", cl::Hidden, cl::init(true), + cl::desc("DAG cominber enable load/<replace bytes>/store with " + "a narrower store")); + namespace { class DAGCombiner { SelectionDAG &DAG; const TargetLowering &TLI; + const SelectionDAGTargetInfo *STI; CombineLevel Level; CodeGenOpt::Level OptLevel; bool LegalDAG = false; bool LegalOperations = false; bool LegalTypes = false; bool ForCodeSize; + bool DisableGenericCombines; /// Worklist of all of the nodes that need to be simplified. /// @@ -222,9 +235,11 @@ namespace { public: DAGCombiner(SelectionDAG &D, AliasAnalysis *AA, CodeGenOpt::Level OL) - : DAG(D), TLI(D.getTargetLoweringInfo()), Level(BeforeLegalizeTypes), - OptLevel(OL), AA(AA) { + : DAG(D), TLI(D.getTargetLoweringInfo()), + STI(D.getSubtarget().getSelectionDAGInfo()), + Level(BeforeLegalizeTypes), OptLevel(OL), AA(AA) { ForCodeSize = DAG.shouldOptForSize(); + DisableGenericCombines = STI && STI->disableGenericCombines(OptLevel); MaximumLegalStoreInBits = 0; // We use the minimum store size here, since that's all we can guarantee @@ -307,23 +322,34 @@ namespace { } bool SimplifyDemandedBits(SDValue Op, const APInt &DemandedBits) { - EVT VT = Op.getValueType(); - unsigned NumElts = VT.isVector() ? VT.getVectorNumElements() : 1; - APInt DemandedElts = APInt::getAllOnesValue(NumElts); - return SimplifyDemandedBits(Op, DemandedBits, DemandedElts); + TargetLowering::TargetLoweringOpt TLO(DAG, LegalTypes, LegalOperations); + KnownBits Known; + if (!TLI.SimplifyDemandedBits(Op, DemandedBits, Known, TLO, 0, false)) + return false; + + // Revisit the node. + AddToWorklist(Op.getNode()); + + CommitTargetLoweringOpt(TLO); + return true; } /// Check the specified vector node value to see if it can be simplified or /// if things it uses can be simplified as it only uses some of the /// elements. If so, return true. bool SimplifyDemandedVectorElts(SDValue Op) { + // TODO: For now just pretend it cannot be simplified. 
+ if (Op.getValueType().isScalableVector()) + return false; + unsigned NumElts = Op.getValueType().getVectorNumElements(); APInt DemandedElts = APInt::getAllOnesValue(NumElts); return SimplifyDemandedVectorElts(Op, DemandedElts); } bool SimplifyDemandedBits(SDValue Op, const APInt &DemandedBits, - const APInt &DemandedElts); + const APInt &DemandedElts, + bool AssumeSingleUse = false); bool SimplifyDemandedVectorElts(SDValue Op, const APInt &DemandedElts, bool AssumeSingleUse = false); @@ -429,11 +455,13 @@ namespace { SDValue visitZERO_EXTEND(SDNode *N); SDValue visitANY_EXTEND(SDNode *N); SDValue visitAssertExt(SDNode *N); + SDValue visitAssertAlign(SDNode *N); SDValue visitSIGN_EXTEND_INREG(SDNode *N); SDValue visitSIGN_EXTEND_VECTOR_INREG(SDNode *N); SDValue visitZERO_EXTEND_VECTOR_INREG(SDNode *N); SDValue visitTRUNCATE(SDNode *N); SDValue visitBITCAST(SDNode *N); + SDValue visitFREEZE(SDNode *N); SDValue visitBUILD_PAIR(SDNode *N); SDValue visitFADD(SDNode *N); SDValue visitFSUB(SDNode *N); @@ -522,9 +550,8 @@ namespace { SDValue rebuildSetCC(SDValue N); bool isSetCCEquivalent(SDValue N, SDValue &LHS, SDValue &RHS, - SDValue &CC) const; + SDValue &CC, bool MatchStrict = false) const; bool isOneUseSetCC(SDValue N) const; - bool isCheaperToUseNegatedFPOps(SDValue X, SDValue Y); SDValue SimplifyNodeWithTwoResults(SDNode *N, unsigned LoOp, unsigned HiOp); @@ -553,6 +580,10 @@ namespace { SDValue InnerPos, SDValue InnerNeg, unsigned PosOpcode, unsigned NegOpcode, const SDLoc &DL); + SDValue MatchFunnelPosNeg(SDValue N0, SDValue N1, SDValue Pos, SDValue Neg, + SDValue InnerPos, SDValue InnerNeg, + unsigned PosOpcode, unsigned NegOpcode, + const SDLoc &DL); SDValue MatchRotate(SDValue LHS, SDValue RHS, const SDLoc &DL); SDValue MatchLoadCombine(SDNode *N); SDValue MatchStoreCombine(StoreSDNode *N); @@ -562,6 +593,7 @@ namespace { SDValue TransformFPLoadStorePair(SDNode *N); SDValue convertBuildVecZextToZext(SDNode *N); SDValue reduceBuildVecExtToExtBuildVec(SDNode *N); + SDValue reduceBuildVecTruncToBitCast(SDNode *N); SDValue reduceBuildVecToShuffle(SDNode *N); SDValue createBuildVecShuffle(const SDLoc &DL, SDNode *N, ArrayRef<int> VectorMask, SDValue VecIn1, @@ -606,6 +638,19 @@ namespace { : MemNode(N), OffsetFromBase(Offset) {} }; + // Classify the origin of a stored value. + enum class StoreSource { Unknown, Constant, Extract, Load }; + StoreSource getStoreSource(SDValue StoreVal) { + if (isa<ConstantSDNode>(StoreVal) || isa<ConstantFPSDNode>(StoreVal)) + return StoreSource::Constant; + if (StoreVal.getOpcode() == ISD::EXTRACT_VECTOR_ELT || + StoreVal.getOpcode() == ISD::EXTRACT_SUBVECTOR) + return StoreSource::Extract; + if (isa<LoadSDNode>(StoreVal)) + return StoreSource::Load; + return StoreSource::Unknown; + } + /// This is a helper function for visitMUL to check the profitability /// of folding (mul (add x, c1), c2) -> (add (mul x, c2), c1*c2). /// MulNode is the original multiply, AddNode is (add x, c1), @@ -633,43 +678,66 @@ namespace { /// can be combined into narrow loads. bool BackwardsPropagateMask(SDNode *N); - /// Helper function for MergeConsecutiveStores which merges the - /// component store chains. + /// Helper function for mergeConsecutiveStores which merges the component + /// store chains. SDValue getMergeStoreChains(SmallVectorImpl<MemOpLink> &StoreNodes, unsigned NumStores); - /// This is a helper function for MergeConsecutiveStores. 
When the - /// source elements of the consecutive stores are all constants or - /// all extracted vector elements, try to merge them into one - /// larger store introducing bitcasts if necessary. \return True - /// if a merged store was created. - bool MergeStoresOfConstantsOrVecElts(SmallVectorImpl<MemOpLink> &StoreNodes, + /// This is a helper function for mergeConsecutiveStores. When the source + /// elements of the consecutive stores are all constants or all extracted + /// vector elements, try to merge them into one larger store introducing + /// bitcasts if necessary. \return True if a merged store was created. + bool mergeStoresOfConstantsOrVecElts(SmallVectorImpl<MemOpLink> &StoreNodes, EVT MemVT, unsigned NumStores, bool IsConstantSrc, bool UseVector, bool UseTrunc); - /// This is a helper function for MergeConsecutiveStores. Stores - /// that potentially may be merged with St are placed in - /// StoreNodes. RootNode is a chain predecessor to all store - /// candidates. + /// This is a helper function for mergeConsecutiveStores. Stores that + /// potentially may be merged with St are placed in StoreNodes. RootNode is + /// a chain predecessor to all store candidates. void getStoreMergeCandidates(StoreSDNode *St, SmallVectorImpl<MemOpLink> &StoreNodes, SDNode *&Root); - /// Helper function for MergeConsecutiveStores. Checks if - /// candidate stores have indirect dependency through their - /// operands. RootNode is the predecessor to all stores calculated - /// by getStoreMergeCandidates and is used to prune the dependency check. - /// \return True if safe to merge. + /// Helper function for mergeConsecutiveStores. Checks if candidate stores + /// have indirect dependency through their operands. RootNode is the + /// predecessor to all stores calculated by getStoreMergeCandidates and is + /// used to prune the dependency check. \return True if safe to merge. bool checkMergeStoreCandidatesForDependencies( SmallVectorImpl<MemOpLink> &StoreNodes, unsigned NumStores, SDNode *RootNode); + /// This is a helper function for mergeConsecutiveStores. Given a list of + /// store candidates, find the first N that are consecutive in memory. + /// Returns 0 if there are not at least 2 consecutive stores to try merging. + unsigned getConsecutiveStores(SmallVectorImpl<MemOpLink> &StoreNodes, + int64_t ElementSizeBytes) const; + + /// This is a helper function for mergeConsecutiveStores. It is used for + /// store chains that are composed entirely of constant values. + bool tryStoreMergeOfConstants(SmallVectorImpl<MemOpLink> &StoreNodes, + unsigned NumConsecutiveStores, + EVT MemVT, SDNode *Root, bool AllowVectors); + + /// This is a helper function for mergeConsecutiveStores. It is used for + /// store chains that are composed entirely of extracted vector elements. + /// When extracting multiple vector elements, try to store them in one + /// vector store rather than a sequence of scalar stores. + bool tryStoreMergeOfExtracts(SmallVectorImpl<MemOpLink> &StoreNodes, + unsigned NumConsecutiveStores, EVT MemVT, + SDNode *Root); + + /// This is a helper function for mergeConsecutiveStores. It is used for + /// store chains that are composed entirely of loaded values. + bool tryStoreMergeOfLoads(SmallVectorImpl<MemOpLink> &StoreNodes, + unsigned NumConsecutiveStores, EVT MemVT, + SDNode *Root, bool AllowVectors, + bool IsNonTemporalStore, bool IsNonTemporalLoad); + /// Merge consecutive store operations into a wide store. /// This optimization uses wide integers or vectors when possible. 
- /// \return number of stores that were merged into a merged store (the - /// affected nodes are stored as a prefix in \p StoreNodes). - bool MergeConsecutiveStores(StoreSDNode *St); + /// \return true if stores were merged. + bool mergeConsecutiveStores(StoreSDNode *St); /// Try to transform a truncation where C is a constant: /// (trunc (and X, C)) -> (and (trunc X), (trunc C)) @@ -814,7 +882,7 @@ static void zeroExtendToMatch(APInt &LHS, APInt &RHS, unsigned Offset = 0) { // the appropriate nodes based on the type of node we are checking. This // simplifies life a bit for the callers. bool DAGCombiner::isSetCCEquivalent(SDValue N, SDValue &LHS, SDValue &RHS, - SDValue &CC) const { + SDValue &CC, bool MatchStrict) const { if (N.getOpcode() == ISD::SETCC) { LHS = N.getOperand(0); RHS = N.getOperand(1); @@ -822,6 +890,15 @@ bool DAGCombiner::isSetCCEquivalent(SDValue N, SDValue &LHS, SDValue &RHS, return true; } + if (MatchStrict && + (N.getOpcode() == ISD::STRICT_FSETCC || + N.getOpcode() == ISD::STRICT_FSETCCS)) { + LHS = N.getOperand(1); + RHS = N.getOperand(2); + CC = N.getOperand(3); + return true; + } + if (N.getOpcode() != ISD::SELECT_CC || !TLI.isConstTrueVal(N.getOperand(2).getNode()) || !TLI.isConstFalseVal(N.getOperand(3).getNode())) @@ -886,6 +963,13 @@ static bool isAnyConstantBuildVector(SDValue V, bool NoOpaques = false) { ISD::isBuildVectorOfConstantFPSDNodes(V.getNode()); } +// Determine if this an indexed load with an opaque target constant index. +static bool canSplitIdx(LoadSDNode *LD) { + return MaySplitLoadIndex && + (LD->getOperand(2).getOpcode() != ISD::TargetConstant || + !cast<ConstantSDNode>(LD->getOperand(2))->isOpaque()); +} + bool DAGCombiner::reassociationCanBreakAddressingModePattern(unsigned Opc, const SDLoc &DL, SDValue N0, @@ -951,14 +1035,11 @@ SDValue DAGCombiner::reassociateOpsCommutative(unsigned Opc, const SDLoc &DL, if (N0.getOpcode() != Opc) return SDValue(); - // Don't reassociate reductions. - if (N0->getFlags().hasVectorReduction()) - return SDValue(); - - if (SDNode *C1 = DAG.isConstantIntBuildVectorOrConstantInt(N0.getOperand(1))) { - if (SDNode *C2 = DAG.isConstantIntBuildVectorOrConstantInt(N1)) { + if (DAG.isConstantIntBuildVectorOrConstantInt(N0.getOperand(1))) { + if (DAG.isConstantIntBuildVectorOrConstantInt(N1)) { // Reassociate: (op (op x, c1), c2) -> (op x, (op c1, c2)) - if (SDValue OpNode = DAG.FoldConstantArithmetic(Opc, DL, VT, C1, C2)) + if (SDValue OpNode = + DAG.FoldConstantArithmetic(Opc, DL, VT, {N0.getOperand(1), N1})) return DAG.getNode(Opc, DL, VT, N0.getOperand(0), OpNode); return SDValue(); } @@ -978,9 +1059,6 @@ SDValue DAGCombiner::reassociateOpsCommutative(unsigned Opc, const SDLoc &DL, SDValue DAGCombiner::reassociateOps(unsigned Opc, const SDLoc &DL, SDValue N0, SDValue N1, SDNodeFlags Flags) { assert(TLI.isCommutativeBinOp(Opc) && "Operation not commutative."); - // Don't reassociate reductions. - if (Flags.hasVectorReduction()) - return SDValue(); // Floating-point reassociation is not allowed without loose FP math. if (N0.getValueType().isFloatingPoint() || @@ -1029,6 +1107,12 @@ SDValue DAGCombiner::CombineTo(SDNode *N, const SDValue *To, unsigned NumTo, void DAGCombiner:: CommitTargetLoweringOpt(const TargetLowering::TargetLoweringOpt &TLO) { + // Replace the old value with the new one. + ++NodesCombined; + LLVM_DEBUG(dbgs() << "\nReplacing.2 "; TLO.Old.getNode()->dump(&DAG); + dbgs() << "\nWith: "; TLO.New.getNode()->dump(&DAG); + dbgs() << '\n'); + // Replace all uses. 
If any nodes become isomorphic to other nodes and // are deleted, make sure to remove them from our worklist. WorklistRemover DeadNodes(*this); @@ -1047,21 +1131,17 @@ CommitTargetLoweringOpt(const TargetLowering::TargetLoweringOpt &TLO) { /// Check the specified integer node value to see if it can be simplified or if /// things it uses can be simplified by bit propagation. If so, return true. bool DAGCombiner::SimplifyDemandedBits(SDValue Op, const APInt &DemandedBits, - const APInt &DemandedElts) { + const APInt &DemandedElts, + bool AssumeSingleUse) { TargetLowering::TargetLoweringOpt TLO(DAG, LegalTypes, LegalOperations); KnownBits Known; - if (!TLI.SimplifyDemandedBits(Op, DemandedBits, DemandedElts, Known, TLO)) + if (!TLI.SimplifyDemandedBits(Op, DemandedBits, DemandedElts, Known, TLO, 0, + AssumeSingleUse)) return false; // Revisit the node. AddToWorklist(Op.getNode()); - // Replace the old value with the new one. - ++NodesCombined; - LLVM_DEBUG(dbgs() << "\nReplacing.2 "; TLO.Old.getNode()->dump(&DAG); - dbgs() << "\nWith: "; TLO.New.getNode()->dump(&DAG); - dbgs() << '\n'); - CommitTargetLoweringOpt(TLO); return true; } @@ -1081,12 +1161,6 @@ bool DAGCombiner::SimplifyDemandedVectorElts(SDValue Op, // Revisit the node. AddToWorklist(Op.getNode()); - // Replace the old value with the new one. - ++NodesCombined; - LLVM_DEBUG(dbgs() << "\nReplacing.2 "; TLO.Old.getNode()->dump(&DAG); - dbgs() << "\nWith: "; TLO.New.getNode()->dump(&DAG); - dbgs() << '\n'); - CommitTargetLoweringOpt(TLO); return true; } @@ -1210,8 +1284,11 @@ SDValue DAGCombiner::PromoteIntBinOp(SDValue Op) { SDValue RV = DAG.getNode(ISD::TRUNCATE, DL, VT, DAG.getNode(Opc, DL, PVT, NN0, NN1)); - // We are always replacing N0/N1's use in N and only need - // additional replacements if there are additional uses. + // We are always replacing N0/N1's use in N and only need additional + // replacements if there are additional uses. + // Note: We are checking uses of the *nodes* (SDNode) rather than values + // (SDValue) here because the node may reference multiple values + // (for example, the chain value of a load node). Replace0 &= !N0->hasOneUse(); Replace1 &= (N0 != N1) && !N1->hasOneUse(); @@ -1561,6 +1638,7 @@ SDValue DAGCombiner::visit(SDNode *N) { case ISD::ANY_EXTEND: return visitANY_EXTEND(N); case ISD::AssertSext: case ISD::AssertZext: return visitAssertExt(N); + case ISD::AssertAlign: return visitAssertAlign(N); case ISD::SIGN_EXTEND_INREG: return visitSIGN_EXTEND_INREG(N); case ISD::SIGN_EXTEND_VECTOR_INREG: return visitSIGN_EXTEND_VECTOR_INREG(N); case ISD::ZERO_EXTEND_VECTOR_INREG: return visitZERO_EXTEND_VECTOR_INREG(N); @@ -1610,6 +1688,7 @@ SDValue DAGCombiner::visit(SDNode *N) { case ISD::LIFETIME_END: return visitLIFETIME_END(N); case ISD::FP_TO_FP16: return visitFP_TO_FP16(N); case ISD::FP16_TO_FP: return visitFP16_TO_FP(N); + case ISD::FREEZE: return visitFREEZE(N); case ISD::VECREDUCE_FADD: case ISD::VECREDUCE_FMUL: case ISD::VECREDUCE_ADD: @@ -1628,7 +1707,9 @@ SDValue DAGCombiner::visit(SDNode *N) { } SDValue DAGCombiner::combine(SDNode *N) { - SDValue RV = visit(N); + SDValue RV; + if (!DisableGenericCombines) + RV = visit(N); // If nothing happened, try a target-specific DAG combine. if (!RV.getNode()) { @@ -2046,12 +2127,11 @@ static SDValue foldAddSubOfSignBit(SDNode *N, SelectionDAG &DAG) { // We need a constant operand for the add/sub, and the other operand is a // logical shift right: add (srl), C or sub C, (srl). - // TODO - support non-uniform vector amounts. 
bool IsAdd = N->getOpcode() == ISD::ADD; SDValue ConstantOp = IsAdd ? N->getOperand(1) : N->getOperand(0); SDValue ShiftOp = IsAdd ? N->getOperand(0) : N->getOperand(1); - ConstantSDNode *C = isConstOrConstSplat(ConstantOp); - if (!C || ShiftOp.getOpcode() != ISD::SRL) + if (!DAG.isConstantIntBuildVectorOrConstantInt(ConstantOp) || + ShiftOp.getOpcode() != ISD::SRL) return SDValue(); // The shift must be of a 'not' value. @@ -2072,8 +2152,11 @@ static SDValue foldAddSubOfSignBit(SDNode *N, SelectionDAG &DAG) { SDLoc DL(N); auto ShOpcode = IsAdd ? ISD::SRA : ISD::SRL; SDValue NewShift = DAG.getNode(ShOpcode, DL, VT, Not.getOperand(0), ShAmt); - APInt NewC = IsAdd ? C->getAPIntValue() + 1 : C->getAPIntValue() - 1; - return DAG.getNode(ISD::ADD, DL, VT, NewShift, DAG.getConstant(NewC, DL, VT)); + if (SDValue NewC = + DAG.FoldConstantArithmetic(IsAdd ? ISD::ADD : ISD::SUB, DL, VT, + {ConstantOp, DAG.getConstant(1, DL, VT)})) + return DAG.getNode(ISD::ADD, DL, VT, NewShift, NewC); + return SDValue(); } /// Try to fold a node that behaves like an ADD (note that N isn't necessarily @@ -2109,8 +2192,7 @@ SDValue DAGCombiner::visitADDLike(SDNode *N) { if (!DAG.isConstantIntBuildVectorOrConstantInt(N1)) return DAG.getNode(ISD::ADD, DL, VT, N1, N0); // fold (add c1, c2) -> c1+c2 - return DAG.FoldConstantArithmetic(ISD::ADD, DL, VT, N0.getNode(), - N1.getNode()); + return DAG.FoldConstantArithmetic(ISD::ADD, DL, VT, {N0, N1}); } // fold (add x, 0) -> x @@ -2121,8 +2203,8 @@ SDValue DAGCombiner::visitADDLike(SDNode *N) { // fold ((A-c1)+c2) -> (A+(c2-c1)) if (N0.getOpcode() == ISD::SUB && isConstantOrConstantVector(N0.getOperand(1), /* NoOpaque */ true)) { - SDValue Sub = DAG.FoldConstantArithmetic(ISD::SUB, DL, VT, N1.getNode(), - N0.getOperand(1).getNode()); + SDValue Sub = + DAG.FoldConstantArithmetic(ISD::SUB, DL, VT, {N1, N0.getOperand(1)}); assert(Sub && "Constant folding failed"); return DAG.getNode(ISD::ADD, DL, VT, N0.getOperand(0), Sub); } @@ -2130,8 +2212,8 @@ SDValue DAGCombiner::visitADDLike(SDNode *N) { // fold ((c1-A)+c2) -> (c1+c2)-A if (N0.getOpcode() == ISD::SUB && isConstantOrConstantVector(N0.getOperand(0), /* NoOpaque */ true)) { - SDValue Add = DAG.FoldConstantArithmetic(ISD::ADD, DL, VT, N1.getNode(), - N0.getOperand(0).getNode()); + SDValue Add = + DAG.FoldConstantArithmetic(ISD::ADD, DL, VT, {N1, N0.getOperand(0)}); assert(Add && "Constant folding failed"); return DAG.getNode(ISD::SUB, DL, VT, Add, N0.getOperand(1)); } @@ -2152,13 +2234,14 @@ SDValue DAGCombiner::visitADDLike(SDNode *N) { } } - // Undo the add -> or combine to merge constant offsets from a frame index. + // Fold (add (or x, c0), c1) -> (add x, (c0 + c1)) if (or x, c0) is + // equivalent to (add x, c0). if (N0.getOpcode() == ISD::OR && - isa<FrameIndexSDNode>(N0.getOperand(0)) && - isa<ConstantSDNode>(N0.getOperand(1)) && + isConstantOrConstantVector(N0.getOperand(1), /* NoOpaque */ true) && DAG.haveNoCommonBitsSet(N0.getOperand(0), N0.getOperand(1))) { - SDValue Add0 = DAG.getNode(ISD::ADD, DL, VT, N1, N0.getOperand(1)); - return DAG.getNode(ISD::ADD, DL, VT, N0.getOperand(0), Add0); + if (SDValue Add0 = DAG.FoldConstantArithmetic(ISD::ADD, DL, VT, + {N1, N0.getOperand(1)})) + return DAG.getNode(ISD::ADD, DL, VT, N0.getOperand(0), Add0); } } @@ -2317,6 +2400,23 @@ SDValue DAGCombiner::visitADD(SDNode *N) { DAG.haveNoCommonBitsSet(N0, N1)) return DAG.getNode(ISD::OR, DL, VT, N0, N1); + // Fold (add (vscale * C0), (vscale * C1)) to (vscale * (C0 + C1)). 
+ if (N0.getOpcode() == ISD::VSCALE && N1.getOpcode() == ISD::VSCALE) { + APInt C0 = N0->getConstantOperandAPInt(0); + APInt C1 = N1->getConstantOperandAPInt(0); + return DAG.getVScale(DL, VT, C0 + C1); + } + + // fold a+vscale(c1)+vscale(c2) -> a+vscale(c1+c2) + if ((N0.getOpcode() == ISD::ADD) && + (N0.getOperand(1).getOpcode() == ISD::VSCALE) && + (N1.getOpcode() == ISD::VSCALE)) { + auto VS0 = N0.getOperand(1)->getConstantOperandAPInt(0); + auto VS1 = N1->getConstantOperandAPInt(0); + auto VS = DAG.getVScale(DL, VT, VS0 + VS1); + return DAG.getNode(ISD::ADD, DL, VT, N0.getOperand(0), VS); + } + return SDValue(); } @@ -2347,8 +2447,7 @@ SDValue DAGCombiner::visitADDSAT(SDNode *N) { if (!DAG.isConstantIntBuildVectorOrConstantInt(N1)) return DAG.getNode(Opcode, DL, VT, N1, N0); // fold (add_sat c1, c2) -> c3 - return DAG.FoldConstantArithmetic(Opcode, DL, VT, N0.getNode(), - N1.getNode()); + return DAG.FoldConstantArithmetic(Opcode, DL, VT, {N0, N1}); } // fold (add_sat x, 0) -> x @@ -2968,12 +3067,10 @@ SDValue DAGCombiner::visitSUB(SDNode *N) { // FIXME: Refactor this and xor and other similar operations together. if (N0 == N1) return tryFoldToZero(DL, TLI, VT, DAG, LegalOperations); - if (DAG.isConstantIntBuildVectorOrConstantInt(N0) && - DAG.isConstantIntBuildVectorOrConstantInt(N1)) { - // fold (sub c1, c2) -> c1-c2 - return DAG.FoldConstantArithmetic(ISD::SUB, DL, VT, N0.getNode(), - N1.getNode()); - } + + // fold (sub c1, c2) -> c3 + if (SDValue C = DAG.FoldConstantArithmetic(ISD::SUB, DL, VT, {N0, N1})) + return C; if (SDValue NewSel = foldBinOpIntoSelect(N)) return NewSel; @@ -3040,8 +3137,8 @@ SDValue DAGCombiner::visitSUB(SDNode *N) { if (N0.getOpcode() == ISD::ADD && isConstantOrConstantVector(N1, /* NoOpaques */ true) && isConstantOrConstantVector(N0.getOperand(1), /* NoOpaques */ true)) { - SDValue NewC = DAG.FoldConstantArithmetic( - ISD::SUB, DL, VT, N0.getOperand(1).getNode(), N1.getNode()); + SDValue NewC = + DAG.FoldConstantArithmetic(ISD::SUB, DL, VT, {N0.getOperand(1), N1}); assert(NewC && "Constant folding failed"); return DAG.getNode(ISD::ADD, DL, VT, N0.getOperand(0), NewC); } @@ -3051,8 +3148,7 @@ SDValue DAGCombiner::visitSUB(SDNode *N) { SDValue N11 = N1.getOperand(1); if (isConstantOrConstantVector(N0, /* NoOpaques */ true) && isConstantOrConstantVector(N11, /* NoOpaques */ true)) { - SDValue NewC = DAG.FoldConstantArithmetic(ISD::SUB, DL, VT, N0.getNode(), - N11.getNode()); + SDValue NewC = DAG.FoldConstantArithmetic(ISD::SUB, DL, VT, {N0, N11}); assert(NewC && "Constant folding failed"); return DAG.getNode(ISD::SUB, DL, VT, NewC, N1.getOperand(0)); } @@ -3062,8 +3158,8 @@ SDValue DAGCombiner::visitSUB(SDNode *N) { if (N0.getOpcode() == ISD::SUB && isConstantOrConstantVector(N1, /* NoOpaques */ true) && isConstantOrConstantVector(N0.getOperand(1), /* NoOpaques */ true)) { - SDValue NewC = DAG.FoldConstantArithmetic( - ISD::ADD, DL, VT, N0.getOperand(1).getNode(), N1.getNode()); + SDValue NewC = + DAG.FoldConstantArithmetic(ISD::ADD, DL, VT, {N0.getOperand(1), N1}); assert(NewC && "Constant folding failed"); return DAG.getNode(ISD::SUB, DL, VT, N0.getOperand(0), NewC); } @@ -3072,8 +3168,8 @@ SDValue DAGCombiner::visitSUB(SDNode *N) { if (N0.getOpcode() == ISD::SUB && isConstantOrConstantVector(N1, /* NoOpaques */ true) && isConstantOrConstantVector(N0.getOperand(0), /* NoOpaques */ true)) { - SDValue NewC = DAG.FoldConstantArithmetic( - ISD::SUB, DL, VT, N0.getOperand(0).getNode(), N1.getNode()); + SDValue NewC = + DAG.FoldConstantArithmetic(ISD::SUB, DL, VT, 
{N0.getOperand(0), N1}); assert(NewC && "Constant folding failed"); return DAG.getNode(ISD::SUB, DL, VT, NewC, N0.getOperand(1)); } @@ -3244,6 +3340,12 @@ SDValue DAGCombiner::visitSUB(SDNode *N) { } } + // canonicalize (sub X, (vscale * C)) to (add X, (vscale * -C)) + if (N1.getOpcode() == ISD::VSCALE) { + APInt IntVal = N1.getConstantOperandAPInt(0); + return DAG.getNode(ISD::ADD, DL, VT, N0, DAG.getVScale(DL, VT, -IntVal)); + } + // Prefer an add for more folding potential and possibly better codegen: // sub N0, (lshr N10, width-1) --> add N0, (ashr N10, width-1) if (!LegalOperations && N1.getOpcode() == ISD::SRL && N1.hasOneUse()) { @@ -3294,12 +3396,9 @@ SDValue DAGCombiner::visitSUBSAT(SDNode *N) { if (N0 == N1) return DAG.getConstant(0, DL, VT); - if (DAG.isConstantIntBuildVectorOrConstantInt(N0) && - DAG.isConstantIntBuildVectorOrConstantInt(N1)) { - // fold (sub_sat c1, c2) -> c3 - return DAG.FoldConstantArithmetic(N->getOpcode(), DL, VT, N0.getNode(), - N1.getNode()); - } + // fold (sub_sat c1, c2) -> c3 + if (SDValue C = DAG.FoldConstantArithmetic(N->getOpcode(), DL, VT, {N0, N1})) + return C; // fold (sub_sat x, 0) -> x if (isNullConstant(N1)) @@ -3435,30 +3534,20 @@ SDValue DAGCombiner::visitMUL(SDNode *N) { if (N0.isUndef() || N1.isUndef()) return DAG.getConstant(0, SDLoc(N), VT); - bool N0IsConst = false; bool N1IsConst = false; bool N1IsOpaqueConst = false; - bool N0IsOpaqueConst = false; - APInt ConstValue0, ConstValue1; + APInt ConstValue1; + // fold vector ops if (VT.isVector()) { if (SDValue FoldedVOp = SimplifyVBinOp(N)) return FoldedVOp; - N0IsConst = ISD::isConstantSplatVector(N0.getNode(), ConstValue0); N1IsConst = ISD::isConstantSplatVector(N1.getNode(), ConstValue1); - assert((!N0IsConst || - ConstValue0.getBitWidth() == VT.getScalarSizeInBits()) && - "Splat APInt should be element width"); assert((!N1IsConst || ConstValue1.getBitWidth() == VT.getScalarSizeInBits()) && "Splat APInt should be element width"); } else { - N0IsConst = isa<ConstantSDNode>(N0); - if (N0IsConst) { - ConstValue0 = cast<ConstantSDNode>(N0)->getAPIntValue(); - N0IsOpaqueConst = cast<ConstantSDNode>(N0)->isOpaque(); - } N1IsConst = isa<ConstantSDNode>(N1); if (N1IsConst) { ConstValue1 = cast<ConstantSDNode>(N1)->getAPIntValue(); @@ -3467,17 +3556,18 @@ SDValue DAGCombiner::visitMUL(SDNode *N) { } // fold (mul c1, c2) -> c1*c2 - if (N0IsConst && N1IsConst && !N0IsOpaqueConst && !N1IsOpaqueConst) - return DAG.FoldConstantArithmetic(ISD::MUL, SDLoc(N), VT, - N0.getNode(), N1.getNode()); + if (SDValue C = DAG.FoldConstantArithmetic(ISD::MUL, SDLoc(N), VT, {N0, N1})) + return C; // canonicalize constant to RHS (vector doesn't have to splat) if (DAG.isConstantIntBuildVectorOrConstantInt(N0) && !DAG.isConstantIntBuildVectorOrConstantInt(N1)) return DAG.getNode(ISD::MUL, SDLoc(N), VT, N1, N0); + // fold (mul x, 0) -> 0 if (N1IsConst && ConstValue1.isNullValue()) return N1; + // fold (mul x, 1) -> x if (N1IsConst && ConstValue1.isOneValue()) return N0; @@ -3491,6 +3581,7 @@ SDValue DAGCombiner::visitMUL(SDNode *N) { return DAG.getNode(ISD::SUB, DL, VT, DAG.getConstant(0, DL, VT), N0); } + // fold (mul x, (1 << c)) -> x << c if (isConstantOrConstantVector(N1, /*NoOpaques*/ true) && DAG.isKnownToBeAPowerOfTwo(N1) && @@ -3501,6 +3592,7 @@ SDValue DAGCombiner::visitMUL(SDNode *N) { SDValue Trunc = DAG.getZExtOrTrunc(LogBase2, DL, ShiftVT); return DAG.getNode(ISD::SHL, DL, VT, N0, Trunc); } + // fold (mul x, -(1 << c)) -> -(x << c) or (-x) << c if (N1IsConst && !N1IsOpaqueConst && 
(-ConstValue1).isPowerOf2()) { unsigned Log2Val = (-ConstValue1).logBase2(); @@ -3589,6 +3681,14 @@ SDValue DAGCombiner::visitMUL(SDNode *N) { DAG.getNode(ISD::MUL, SDLoc(N1), VT, N0.getOperand(1), N1)); + // Fold (mul (vscale * C0), C1) to (vscale * (C0 * C1)). + if (N0.getOpcode() == ISD::VSCALE) + if (ConstantSDNode *NC1 = isConstOrConstSplat(N1)) { + APInt C0 = N0.getConstantOperandAPInt(0); + APInt C1 = NC1->getAPIntValue(); + return DAG.getVScale(SDLoc(N), VT, C0 * C1); + } + // reassociate mul if (SDValue RMUL = reassociateOps(ISD::MUL, SDLoc(N), N0, N1, N->getFlags())) return RMUL; @@ -3746,13 +3846,14 @@ SDValue DAGCombiner::visitSDIV(SDNode *N) { SDLoc DL(N); // fold (sdiv c1, c2) -> c1/c2 - ConstantSDNode *N0C = isConstOrConstSplat(N0); ConstantSDNode *N1C = isConstOrConstSplat(N1); - if (N0C && N1C && !N0C->isOpaque() && !N1C->isOpaque()) - return DAG.FoldConstantArithmetic(ISD::SDIV, DL, VT, N0C, N1C); + if (SDValue C = DAG.FoldConstantArithmetic(ISD::SDIV, DL, VT, {N0, N1})) + return C; + // fold (sdiv X, -1) -> 0-X if (N1C && N1C->isAllOnesValue()) return DAG.getNode(ISD::SUB, DL, VT, DAG.getConstant(0, DL, VT), N0); + // fold (sdiv X, MIN_SIGNED) -> select(X == MIN_SIGNED, 1, 0) if (N1C && N1C->getAPIntValue().isMinSignedValue()) return DAG.getSelect(DL, VT, DAG.getSetCC(DL, CCVT, N0, N1, ISD::SETEQ), @@ -3890,12 +3991,10 @@ SDValue DAGCombiner::visitUDIV(SDNode *N) { SDLoc DL(N); // fold (udiv c1, c2) -> c1/c2 - ConstantSDNode *N0C = isConstOrConstSplat(N0); ConstantSDNode *N1C = isConstOrConstSplat(N1); - if (N0C && N1C) - if (SDValue Folded = DAG.FoldConstantArithmetic(ISD::UDIV, DL, VT, - N0C, N1C)) - return Folded; + if (SDValue C = DAG.FoldConstantArithmetic(ISD::UDIV, DL, VT, {N0, N1})) + return C; + // fold (udiv X, -1) -> select(X == -1, 1, 0) if (N1C && N1C->getAPIntValue().isAllOnesValue()) return DAG.getSelect(DL, VT, DAG.getSetCC(DL, CCVT, N0, N1, ISD::SETEQ), @@ -3988,11 +4087,10 @@ SDValue DAGCombiner::visitREM(SDNode *N) { SDLoc DL(N); // fold (rem c1, c2) -> c1%c2 - ConstantSDNode *N0C = isConstOrConstSplat(N0); ConstantSDNode *N1C = isConstOrConstSplat(N1); - if (N0C && N1C) - if (SDValue Folded = DAG.FoldConstantArithmetic(Opcode, DL, VT, N0C, N1C)) - return Folded; + if (SDValue C = DAG.FoldConstantArithmetic(Opcode, DL, VT, {N0, N1})) + return C; + // fold (urem X, -1) -> select(X == -1, 0, x) if (!isSigned && N1C && N1C->getAPIntValue().isAllOnesValue()) return DAG.getSelect(DL, VT, DAG.getSetCC(DL, CCVT, N0, N1, ISD::SETEQ), @@ -4088,7 +4186,7 @@ SDValue DAGCombiner::visitMULHS(SDNode *N) { // If the type twice as wide is legal, transform the mulhs to a wider multiply // plus a shift. - if (VT.isSimple() && !VT.isVector()) { + if (!TLI.isMulhCheaperThanMulShift(VT) && VT.isSimple() && !VT.isVector()) { MVT Simple = VT.getSimpleVT(); unsigned SimpleSize = Simple.getSizeInBits(); EVT NewVT = EVT::getIntegerVT(*DAG.getContext(), SimpleSize*2); @@ -4144,7 +4242,7 @@ SDValue DAGCombiner::visitMULHU(SDNode *N) { // If the type twice as wide is legal, transform the mulhu to a wider multiply // plus a shift. 
- if (VT.isSimple() && !VT.isVector()) { + if (!TLI.isMulhCheaperThanMulShift(VT) && VT.isSimple() && !VT.isVector()) { MVT Simple = VT.getSimpleVT(); unsigned SimpleSize = Simple.getSizeInBits(); EVT NewVT = EVT::getIntegerVT(*DAG.getContext(), SimpleSize*2); @@ -4317,6 +4415,7 @@ SDValue DAGCombiner::visitIMINMAX(SDNode *N) { SDValue N0 = N->getOperand(0); SDValue N1 = N->getOperand(1); EVT VT = N0.getValueType(); + unsigned Opcode = N->getOpcode(); // fold vector ops if (VT.isVector()) @@ -4324,19 +4423,16 @@ SDValue DAGCombiner::visitIMINMAX(SDNode *N) { return FoldedVOp; // fold operation with constant operands. - ConstantSDNode *N0C = getAsNonOpaqueConstant(N0); - ConstantSDNode *N1C = getAsNonOpaqueConstant(N1); - if (N0C && N1C) - return DAG.FoldConstantArithmetic(N->getOpcode(), SDLoc(N), VT, N0C, N1C); + if (SDValue C = DAG.FoldConstantArithmetic(Opcode, SDLoc(N), VT, {N0, N1})) + return C; // canonicalize constant to RHS if (DAG.isConstantIntBuildVectorOrConstantInt(N0) && - !DAG.isConstantIntBuildVectorOrConstantInt(N1)) + !DAG.isConstantIntBuildVectorOrConstantInt(N1)) return DAG.getNode(N->getOpcode(), SDLoc(N), VT, N1, N0); // Is sign bits are zero, flip between UMIN/UMAX and SMIN/SMAX. // Only do this if the current op isn't legal and the flipped is. - unsigned Opcode = N->getOpcode(); if (!TLI.isOperationLegal(Opcode, VT) && (N0.isUndef() || DAG.SignBitIsZero(N0)) && (N1.isUndef() || DAG.SignBitIsZero(N1))) { @@ -4825,11 +4921,16 @@ bool DAGCombiner::isLegalNarrowLdSt(LSBaseSDNode *LDST, return false; // Ensure that this isn't going to produce an unsupported memory access. - if (ShAmt && - !TLI.allowsMemoryAccess(*DAG.getContext(), DAG.getDataLayout(), MemVT, - LDST->getAddressSpace(), ShAmt / 8, - LDST->getMemOperand()->getFlags())) - return false; + if (ShAmt) { + assert(ShAmt % 8 == 0 && "ShAmt is byte offset"); + const unsigned ByteShAmt = ShAmt / 8; + const Align LDSTAlign = LDST->getAlign(); + const Align NarrowAlign = commonAlignment(LDSTAlign, ByteShAmt); + if (!TLI.allowsMemoryAccess(*DAG.getContext(), DAG.getDataLayout(), MemVT, + LDST->getAddressSpace(), NarrowAlign, + LDST->getMemOperand()->getFlags())) + return false; + } // It's not possible to generate a constant of extended or untyped type. 
EVT PtrType = LDST->getBasePtr().getValueType(); @@ -5174,17 +5275,19 @@ SDValue DAGCombiner::visitAND(SDNode *N) { } // fold (and c1, c2) -> c1&c2 - ConstantSDNode *N0C = getAsNonOpaqueConstant(N0); ConstantSDNode *N1C = isConstOrConstSplat(N1); - if (N0C && N1C && !N1C->isOpaque()) - return DAG.FoldConstantArithmetic(ISD::AND, SDLoc(N), VT, N0C, N1C); + if (SDValue C = DAG.FoldConstantArithmetic(ISD::AND, SDLoc(N), VT, {N0, N1})) + return C; + // canonicalize constant to RHS if (DAG.isConstantIntBuildVectorOrConstantInt(N0) && !DAG.isConstantIntBuildVectorOrConstantInt(N1)) return DAG.getNode(ISD::AND, SDLoc(N), VT, N1, N0); + // fold (and x, -1) -> x if (isAllOnesConstant(N1)) return N0; + // if (and x, c) is known to be zero, return 0 unsigned BitWidth = VT.getScalarSizeInBits(); if (N1C && DAG.MaskedValueIsZero(SDValue(N, 0), @@ -5654,6 +5757,48 @@ static bool isBSwapHWordPair(SDValue N, MutableArrayRef<SDNode *> Parts) { return false; } +// Match this pattern: +// (or (and (shl (A, 8)), 0xff00ff00), (and (srl (A, 8)), 0x00ff00ff)) +// And rewrite this to: +// (rotr (bswap A), 16) +static SDValue matchBSwapHWordOrAndAnd(const TargetLowering &TLI, + SelectionDAG &DAG, SDNode *N, SDValue N0, + SDValue N1, EVT VT, EVT ShiftAmountTy) { + assert(N->getOpcode() == ISD::OR && VT == MVT::i32 && + "MatchBSwapHWordOrAndAnd: expecting i32"); + if (!TLI.isOperationLegalOrCustom(ISD::ROTR, VT)) + return SDValue(); + if (N0.getOpcode() != ISD::AND || N1.getOpcode() != ISD::AND) + return SDValue(); + // TODO: this is too restrictive; lifting this restriction requires more tests + if (!N0->hasOneUse() || !N1->hasOneUse()) + return SDValue(); + ConstantSDNode *Mask0 = isConstOrConstSplat(N0.getOperand(1)); + ConstantSDNode *Mask1 = isConstOrConstSplat(N1.getOperand(1)); + if (!Mask0 || !Mask1) + return SDValue(); + if (Mask0->getAPIntValue() != 0xff00ff00 || + Mask1->getAPIntValue() != 0x00ff00ff) + return SDValue(); + SDValue Shift0 = N0.getOperand(0); + SDValue Shift1 = N1.getOperand(0); + if (Shift0.getOpcode() != ISD::SHL || Shift1.getOpcode() != ISD::SRL) + return SDValue(); + ConstantSDNode *ShiftAmt0 = isConstOrConstSplat(Shift0.getOperand(1)); + ConstantSDNode *ShiftAmt1 = isConstOrConstSplat(Shift1.getOperand(1)); + if (!ShiftAmt0 || !ShiftAmt1) + return SDValue(); + if (ShiftAmt0->getAPIntValue() != 8 || ShiftAmt1->getAPIntValue() != 8) + return SDValue(); + if (Shift0.getOperand(0) != Shift1.getOperand(0)) + return SDValue(); + + SDLoc DL(N); + SDValue BSwap = DAG.getNode(ISD::BSWAP, DL, VT, Shift0.getOperand(0)); + SDValue ShAmt = DAG.getConstant(16, DL, ShiftAmountTy); + return DAG.getNode(ISD::ROTR, DL, VT, BSwap, ShAmt); +} + /// Match a 32-bit packed halfword bswap. That is /// ((x & 0x000000ff) << 8) | /// ((x & 0x0000ff00) >> 8) | @@ -5670,6 +5815,16 @@ SDValue DAGCombiner::MatchBSwapHWord(SDNode *N, SDValue N0, SDValue N1) { if (!TLI.isOperationLegalOrCustom(ISD::BSWAP, VT)) return SDValue(); + if (SDValue BSwap = matchBSwapHWordOrAndAnd(TLI, DAG, N, N0, N1, VT, + getShiftAmountTy(VT))) + return BSwap; + + // Try again with commuted operands. 
+ if (SDValue BSwap = matchBSwapHWordOrAndAnd(TLI, DAG, N, N1, N0, VT, + getShiftAmountTy(VT))) + return BSwap; + + // Look for either // (or (bswaphpair), (bswaphpair)) // (or (or (bswaphpair), (and)), (and)) @@ -5875,17 +6030,19 @@ SDValue DAGCombiner::visitOR(SDNode *N) { } // fold (or c1, c2) -> c1|c2 - ConstantSDNode *N0C = getAsNonOpaqueConstant(N0); ConstantSDNode *N1C = dyn_cast<ConstantSDNode>(N1); - if (N0C && N1C && !N1C->isOpaque()) - return DAG.FoldConstantArithmetic(ISD::OR, SDLoc(N), VT, N0C, N1C); + if (SDValue C = DAG.FoldConstantArithmetic(ISD::OR, SDLoc(N), VT, {N0, N1})) + return C; + // canonicalize constant to RHS if (DAG.isConstantIntBuildVectorOrConstantInt(N0) && !DAG.isConstantIntBuildVectorOrConstantInt(N1)) return DAG.getNode(ISD::OR, SDLoc(N), VT, N1, N0); + // fold (or x, 0) -> x if (isNullConstant(N1)) return N0; + // fold (or x, -1) -> -1 if (isAllOnesConstant(N1)) return N1; @@ -5920,8 +6077,8 @@ SDValue DAGCombiner::visitOR(SDNode *N) { }; if (N0.getOpcode() == ISD::AND && N0.getNode()->hasOneUse() && ISD::matchBinaryPredicate(N0.getOperand(1), N1, MatchIntersect, true)) { - if (SDValue COR = DAG.FoldConstantArithmetic( - ISD::OR, SDLoc(N1), VT, N1.getNode(), N0.getOperand(1).getNode())) { + if (SDValue COR = DAG.FoldConstantArithmetic(ISD::OR, SDLoc(N1), VT, + {N1, N0.getOperand(1)})) { SDValue IOR = DAG.getNode(ISD::OR, SDLoc(N0), VT, N0.getOperand(0), N1); AddToWorklist(IOR.getNode()); return DAG.getNode(ISD::AND, SDLoc(N), VT, COR, IOR); @@ -6020,6 +6177,7 @@ static SDValue extractShiftForRotate(SelectionDAG &DAG, SDValue OppShift, ConstantSDNode *OppShiftCst = isConstOrConstSplat(OppShift.getOperand(1)); // (add v v) -> (shl v 1) + // TODO: Should this be a general DAG canonicalization? if (OppShift.getOpcode() == ISD::SRL && OppShiftCst && ExtractFrom.getOpcode() == ISD::ADD && ExtractFrom.getOperand(0) == ExtractFrom.getOperand(1) && @@ -6192,8 +6350,12 @@ static bool matchRotateSub(SDValue Pos, SDValue Neg, unsigned EltSize, // EltSize & Mask == NegC & Mask // // (because "x & Mask" is a truncation and distributes through subtraction). + // + // We also need to account for a potential truncation of NegOp1 if the amount + // has already been legalized to a shift amount type. APInt Width; - if (Pos == NegOp1) + if ((Pos == NegOp1) || + (NegOp1.getOpcode() == ISD::TRUNCATE && Pos == NegOp1.getOperand(0))) Width = NegC->getAPIntValue(); // Check for cases where Pos has the form (add NegOp1, PosC) for some PosC. @@ -6246,19 +6408,91 @@ SDValue DAGCombiner::MatchRotatePosNeg(SDValue Shifted, SDValue Pos, return SDValue(); } +// A subroutine of MatchRotate used once we have found an OR of two opposite +// shifts of N0 + N1. If Neg == <operand size> - Pos then the OR reduces +// to both (PosOpcode N0, N1, Pos) and (NegOpcode N0, N1, Neg), with the +// former being preferred if supported. InnerPos and InnerNeg are Pos and +// Neg with outer conversions stripped away. +// TODO: Merge with MatchRotatePosNeg. 
+SDValue DAGCombiner::MatchFunnelPosNeg(SDValue N0, SDValue N1, SDValue Pos, + SDValue Neg, SDValue InnerPos, + SDValue InnerNeg, unsigned PosOpcode, + unsigned NegOpcode, const SDLoc &DL) { + EVT VT = N0.getValueType(); + unsigned EltBits = VT.getScalarSizeInBits(); + + // fold (or (shl x0, (*ext y)), + // (srl x1, (*ext (sub 32, y)))) -> + // (fshl x0, x1, y) or (fshr x0, x1, (sub 32, y)) + // + // fold (or (shl x0, (*ext (sub 32, y))), + // (srl x1, (*ext y))) -> + // (fshr x0, x1, y) or (fshl x0, x1, (sub 32, y)) + if (matchRotateSub(InnerPos, InnerNeg, EltBits, DAG)) { + bool HasPos = TLI.isOperationLegalOrCustom(PosOpcode, VT); + return DAG.getNode(HasPos ? PosOpcode : NegOpcode, DL, VT, N0, N1, + HasPos ? Pos : Neg); + } + + // Matching the shift+xor cases, we can't easily use the xor'd shift amount + // so for now just use the PosOpcode case if its legal. + // TODO: When can we use the NegOpcode case? + if (PosOpcode == ISD::FSHL && isPowerOf2_32(EltBits)) { + auto IsBinOpImm = [](SDValue Op, unsigned BinOpc, unsigned Imm) { + if (Op.getOpcode() != BinOpc) + return false; + ConstantSDNode *Cst = isConstOrConstSplat(Op.getOperand(1)); + return Cst && (Cst->getAPIntValue() == Imm); + }; + + // fold (or (shl x0, y), (srl (srl x1, 1), (xor y, 31))) + // -> (fshl x0, x1, y) + if (IsBinOpImm(N1, ISD::SRL, 1) && + IsBinOpImm(InnerNeg, ISD::XOR, EltBits - 1) && + InnerPos == InnerNeg.getOperand(0) && + TLI.isOperationLegalOrCustom(ISD::FSHL, VT)) { + return DAG.getNode(ISD::FSHL, DL, VT, N0, N1.getOperand(0), Pos); + } + + // fold (or (shl (shl x0, 1), (xor y, 31)), (srl x1, y)) + // -> (fshr x0, x1, y) + if (IsBinOpImm(N0, ISD::SHL, 1) && + IsBinOpImm(InnerPos, ISD::XOR, EltBits - 1) && + InnerNeg == InnerPos.getOperand(0) && + TLI.isOperationLegalOrCustom(ISD::FSHR, VT)) { + return DAG.getNode(ISD::FSHR, DL, VT, N0.getOperand(0), N1, Neg); + } + + // fold (or (shl (add x0, x0), (xor y, 31)), (srl x1, y)) + // -> (fshr x0, x1, y) + // TODO: Should add(x,x) -> shl(x,1) be a general DAG canonicalization? + if (N0.getOpcode() == ISD::ADD && N0.getOperand(0) == N0.getOperand(1) && + IsBinOpImm(InnerPos, ISD::XOR, EltBits - 1) && + InnerNeg == InnerPos.getOperand(0) && + TLI.isOperationLegalOrCustom(ISD::FSHR, VT)) { + return DAG.getNode(ISD::FSHR, DL, VT, N0.getOperand(0), N1, Neg); + } + } + + return SDValue(); +} + // MatchRotate - Handle an 'or' of two operands. If this is one of the many // idioms for rotate, and if the target supports rotation instructions, generate -// a rot[lr]. +// a rot[lr]. This also matches funnel shift patterns, similar to rotation but +// with different shifted sources. SDValue DAGCombiner::MatchRotate(SDValue LHS, SDValue RHS, const SDLoc &DL) { // Must be a legal type. Expanded 'n promoted things won't work with rotates. EVT VT = LHS.getValueType(); if (!TLI.isTypeLegal(VT)) return SDValue(); - // The target must have at least one rotate flavor. + // The target must have at least one rotate/funnel flavor. bool HasROTL = hasOperation(ISD::ROTL, VT); bool HasROTR = hasOperation(ISD::ROTR, VT); - if (!HasROTL && !HasROTR) + bool HasFSHL = hasOperation(ISD::FSHL, VT); + bool HasFSHR = hasOperation(ISD::FSHR, VT); + if (!HasROTL && !HasROTR && !HasFSHL && !HasFSHR) return SDValue(); // Check for truncated rotate. @@ -6308,12 +6542,13 @@ SDValue DAGCombiner::MatchRotate(SDValue LHS, SDValue RHS, const SDLoc &DL) { // At this point we've matched or extracted a shift op on each side. 
- if (LHSShift.getOperand(0) != RHSShift.getOperand(0)) - return SDValue(); // Not shifting the same value. - if (LHSShift.getOpcode() == RHSShift.getOpcode()) return SDValue(); // Shifts must disagree. + bool IsRotate = LHSShift.getOperand(0) == RHSShift.getOperand(0); + if (!IsRotate && !(HasFSHL || HasFSHR)) + return SDValue(); // Requires funnel shift support. + // Canonicalize shl to left side in a shl/srl pair. if (RHSShift.getOpcode() == ISD::SHL) { std::swap(LHS, RHS); @@ -6329,13 +6564,21 @@ SDValue DAGCombiner::MatchRotate(SDValue LHS, SDValue RHS, const SDLoc &DL) { // fold (or (shl x, C1), (srl x, C2)) -> (rotl x, C1) // fold (or (shl x, C1), (srl x, C2)) -> (rotr x, C2) + // fold (or (shl x, C1), (srl y, C2)) -> (fshl x, y, C1) + // fold (or (shl x, C1), (srl y, C2)) -> (fshr x, y, C2) + // iff C1+C2 == EltSizeInBits auto MatchRotateSum = [EltSizeInBits](ConstantSDNode *LHS, ConstantSDNode *RHS) { return (LHS->getAPIntValue() + RHS->getAPIntValue()) == EltSizeInBits; }; if (ISD::matchBinaryPredicate(LHSShiftAmt, RHSShiftAmt, MatchRotateSum)) { - SDValue Rot = DAG.getNode(HasROTL ? ISD::ROTL : ISD::ROTR, DL, VT, - LHSShiftArg, HasROTL ? LHSShiftAmt : RHSShiftAmt); + SDValue Res; + if (IsRotate && (HasROTL || HasROTR)) + Res = DAG.getNode(HasROTL ? ISD::ROTL : ISD::ROTR, DL, VT, LHSShiftArg, + HasROTL ? LHSShiftAmt : RHSShiftAmt); + else + Res = DAG.getNode(HasFSHL ? ISD::FSHL : ISD::FSHR, DL, VT, LHSShiftArg, + RHSShiftArg, HasFSHL ? LHSShiftAmt : RHSShiftAmt); // If there is an AND of either shifted operand, apply it to the result. if (LHSMask.getNode() || RHSMask.getNode()) { @@ -6353,10 +6596,10 @@ SDValue DAGCombiner::MatchRotate(SDValue LHS, SDValue RHS, const SDLoc &DL) { DAG.getNode(ISD::OR, DL, VT, RHSMask, LHSBits)); } - Rot = DAG.getNode(ISD::AND, DL, VT, Rot, Mask); + Res = DAG.getNode(ISD::AND, DL, VT, Res, Mask); } - return Rot; + return Res; } // If there is a mask here, and we have a variable shift, we can't be sure @@ -6379,13 +6622,29 @@ SDValue DAGCombiner::MatchRotate(SDValue LHS, SDValue RHS, const SDLoc &DL) { RExtOp0 = RHSShiftAmt.getOperand(0); } - SDValue TryL = MatchRotatePosNeg(LHSShiftArg, LHSShiftAmt, RHSShiftAmt, - LExtOp0, RExtOp0, ISD::ROTL, ISD::ROTR, DL); + if (IsRotate && (HasROTL || HasROTR)) { + SDValue TryL = + MatchRotatePosNeg(LHSShiftArg, LHSShiftAmt, RHSShiftAmt, LExtOp0, + RExtOp0, ISD::ROTL, ISD::ROTR, DL); + if (TryL) + return TryL; + + SDValue TryR = + MatchRotatePosNeg(RHSShiftArg, RHSShiftAmt, LHSShiftAmt, RExtOp0, + LExtOp0, ISD::ROTR, ISD::ROTL, DL); + if (TryR) + return TryR; + } + + SDValue TryL = + MatchFunnelPosNeg(LHSShiftArg, RHSShiftArg, LHSShiftAmt, RHSShiftAmt, + LExtOp0, RExtOp0, ISD::FSHL, ISD::FSHR, DL); if (TryL) return TryL; - SDValue TryR = MatchRotatePosNeg(RHSShiftArg, RHSShiftAmt, LHSShiftAmt, - RExtOp0, LExtOp0, ISD::ROTR, ISD::ROTL, DL); + SDValue TryR = + MatchFunnelPosNeg(LHSShiftArg, RHSShiftArg, RHSShiftAmt, LHSShiftAmt, + RExtOp0, LExtOp0, ISD::FSHR, ISD::FSHL, DL); if (TryR) return TryR; @@ -6610,9 +6869,9 @@ SDValue DAGCombiner::MatchStoreCombine(StoreSDNode *N) { if (LegalOperations && !TLI.isOperationLegal(ISD::STORE, VT)) return SDValue(); - // Check if all the bytes of the combined value we are looking at are stored - // to the same base address. Collect bytes offsets from Base address into - // ByteOffsets. + // Check if all the bytes of the combined value we are looking at are stored + // to the same base address. Collect bytes offsets from Base address into + // ByteOffsets. 
SDValue CombinedValue; SmallVector<int64_t, 8> ByteOffsets(Width, INT64_MAX); int64_t FirstOffset = INT64_MAX; @@ -6630,17 +6889,16 @@ SDValue DAGCombiner::MatchStoreCombine(StoreSDNode *N) { SDValue Value = Trunc.getOperand(0); if (Value.getOpcode() == ISD::SRL || Value.getOpcode() == ISD::SRA) { - ConstantSDNode *ShiftOffset = - dyn_cast<ConstantSDNode>(Value.getOperand(1)); - // Trying to match the following pattern. The shift offset must be + auto *ShiftOffset = dyn_cast<ConstantSDNode>(Value.getOperand(1)); + // Trying to match the following pattern. The shift offset must be // a constant and a multiple of 8. It is the byte offset in "y". - // + // // x = srl y, offset - // i8 z = trunc x + // i8 z = trunc x // store z, ... if (!ShiftOffset || (ShiftOffset->getSExtValue() % 8)) return SDValue(); - + Offset = ShiftOffset->getSExtValue()/8; Value = Value.getOperand(0); } @@ -6685,7 +6943,7 @@ SDValue DAGCombiner::MatchStoreCombine(StoreSDNode *N) { assert(FirstOffset != INT64_MAX && "First byte offset must be set"); assert(FirstStore && "First store must be set"); - // Check if the bytes of the combined value we are looking at match with + // Check if the bytes of the combined value we are looking at match with // either big or little endian value store. Optional<bool> IsBigEndian = isBigEndian(ByteOffsets, FirstOffset); if (!IsBigEndian.hasValue()) @@ -7030,20 +7288,22 @@ SDValue DAGCombiner::visitXOR(SDNode *N) { SDLoc DL(N); if (N0.isUndef() && N1.isUndef()) return DAG.getConstant(0, DL, VT); + // fold (xor x, undef) -> undef if (N0.isUndef()) return N0; if (N1.isUndef()) return N1; + // fold (xor c1, c2) -> c1^c2 - ConstantSDNode *N0C = getAsNonOpaqueConstant(N0); - ConstantSDNode *N1C = getAsNonOpaqueConstant(N1); - if (N0C && N1C) - return DAG.FoldConstantArithmetic(ISD::XOR, DL, VT, N0C, N1C); + if (SDValue C = DAG.FoldConstantArithmetic(ISD::XOR, DL, VT, {N0, N1})) + return C; + // canonicalize constant to RHS if (DAG.isConstantIntBuildVectorOrConstantInt(N0) && !DAG.isConstantIntBuildVectorOrConstantInt(N1)) return DAG.getNode(ISD::XOR, DL, VT, N1, N0); + // fold (xor x, 0) -> x if (isNullConstant(N1)) return N0; @@ -7058,7 +7318,8 @@ SDValue DAGCombiner::visitXOR(SDNode *N) { // fold !(x cc y) -> (x !cc y) unsigned N0Opcode = N0.getOpcode(); SDValue LHS, RHS, CC; - if (TLI.isConstTrueVal(N1.getNode()) && isSetCCEquivalent(N0, LHS, RHS, CC)) { + if (TLI.isConstTrueVal(N1.getNode()) && + isSetCCEquivalent(N0, LHS, RHS, CC, /*MatchStrict*/true)) { ISD::CondCode NotCC = ISD::getSetCCInverse(cast<CondCodeSDNode>(CC)->get(), LHS.getValueType()); if (!LegalOperations || @@ -7071,6 +7332,21 @@ SDValue DAGCombiner::visitXOR(SDNode *N) { case ISD::SELECT_CC: return DAG.getSelectCC(SDLoc(N0), LHS, RHS, N0.getOperand(2), N0.getOperand(3), NotCC); + case ISD::STRICT_FSETCC: + case ISD::STRICT_FSETCCS: { + if (N0.hasOneUse()) { + // FIXME Can we handle multiple uses? Could we token factor the chain + // results from the new/old setcc? + SDValue SetCC = DAG.getSetCC(SDLoc(N0), VT, LHS, RHS, NotCC, + N0.getOperand(0), + N0Opcode == ISD::STRICT_FSETCCS); + CombineTo(N, SetCC); + DAG.ReplaceAllUsesOfValueWith(N0.getValue(1), SetCC.getValue(1)); + recursivelyDeleteUnusedNodes(N0.getNode()); + return SDValue(N, 0); // Return N so it doesn't get rechecked! + } + break; + } } } } @@ -7405,15 +7681,29 @@ SDValue DAGCombiner::visitRotate(SDNode *N) { } // fold (rot x, c) -> (rot x, c % BitSize) - // TODO - support non-uniform vector amounts. 
- if (ConstantSDNode *Cst = isConstOrConstSplat(N1)) { - if (Cst->getAPIntValue().uge(Bitsize)) { - uint64_t RotAmt = Cst->getAPIntValue().urem(Bitsize); - return DAG.getNode(N->getOpcode(), dl, VT, N0, - DAG.getConstant(RotAmt, dl, N1.getValueType())); - } + bool OutOfRange = false; + auto MatchOutOfRange = [Bitsize, &OutOfRange](ConstantSDNode *C) { + OutOfRange |= C->getAPIntValue().uge(Bitsize); + return true; + }; + if (ISD::matchUnaryPredicate(N1, MatchOutOfRange) && OutOfRange) { + EVT AmtVT = N1.getValueType(); + SDValue Bits = DAG.getConstant(Bitsize, dl, AmtVT); + if (SDValue Amt = + DAG.FoldConstantArithmetic(ISD::UREM, dl, AmtVT, {N1, Bits})) + return DAG.getNode(N->getOpcode(), dl, VT, N0, Amt); } + // rot i16 X, 8 --> bswap X + auto *RotAmtC = isConstOrConstSplat(N1); + if (RotAmtC && RotAmtC->getAPIntValue() == 8 && + VT.getScalarSizeInBits() == 16 && hasOperation(ISD::BSWAP, VT)) + return DAG.getNode(ISD::BSWAP, dl, VT, N0); + + // Simplify the operands using demanded-bits information. + if (SimplifyDemandedBits(SDValue(N, 0))) + return SDValue(N, 0); + // fold (rot* x, (trunc (and y, c))) -> (rot* x, (and (trunc y), (trunc c))). if (N1.getOpcode() == ISD::TRUNCATE && N1.getOperand(0).getOpcode() == ISD::AND) { @@ -7430,12 +7720,11 @@ SDValue DAGCombiner::visitRotate(SDNode *N) { EVT ShiftVT = C1->getValueType(0); bool SameSide = (N->getOpcode() == NextOp); unsigned CombineOp = SameSide ? ISD::ADD : ISD::SUB; - if (SDValue CombinedShift = - DAG.FoldConstantArithmetic(CombineOp, dl, ShiftVT, C1, C2)) { + if (SDValue CombinedShift = DAG.FoldConstantArithmetic( + CombineOp, dl, ShiftVT, {N1, N0.getOperand(1)})) { SDValue BitsizeC = DAG.getConstant(Bitsize, dl, ShiftVT); SDValue CombinedShiftNorm = DAG.FoldConstantArithmetic( - ISD::SREM, dl, ShiftVT, CombinedShift.getNode(), - BitsizeC.getNode()); + ISD::SREM, dl, ShiftVT, {CombinedShift, BitsizeC}); return DAG.getNode(N->getOpcode(), dl, VT, N0->getOperand(0), CombinedShiftNorm); } @@ -7471,8 +7760,8 @@ SDValue DAGCombiner::visitSHL(SDNode *N) { if (N01CV && N01CV->isConstant() && N00.getOpcode() == ISD::SETCC && TLI.getBooleanContents(N00.getOperand(0).getValueType()) == TargetLowering::ZeroOrNegativeOneBooleanContent) { - if (SDValue C = DAG.FoldConstantArithmetic(ISD::SHL, SDLoc(N), VT, - N01CV, N1CV)) + if (SDValue C = + DAG.FoldConstantArithmetic(ISD::SHL, SDLoc(N), VT, {N01, N1})) return DAG.getNode(ISD::AND, SDLoc(N), VT, N00, C); } } @@ -7482,10 +7771,8 @@ SDValue DAGCombiner::visitSHL(SDNode *N) { ConstantSDNode *N1C = isConstOrConstSplat(N1); // fold (shl c1, c2) -> c1<<c2 - // TODO - support non-uniform vector shift amounts. - ConstantSDNode *N0C = getAsNonOpaqueConstant(N0); - if (N0C && N1C && !N1C->isOpaque()) - return DAG.FoldConstantArithmetic(ISD::SHL, SDLoc(N), VT, N0C, N1C); + if (SDValue C = DAG.FoldConstantArithmetic(ISD::SHL, SDLoc(N), VT, {N0, N1})) + return C; if (SDValue NewSel = foldBinOpIntoSelect(N)) return NewSel; @@ -7502,8 +7789,7 @@ SDValue DAGCombiner::visitSHL(SDNode *N) { return DAG.getNode(ISD::SHL, SDLoc(N), VT, N0, NewOp1); } - // TODO - support non-uniform vector shift amounts. - if (N1C && SimplifyDemandedBits(SDValue(N, 0))) + if (SimplifyDemandedBits(SDValue(N, 0))) return SDValue(N, 0); // fold (shl (shl x, c1), c2) -> 0 or (shl x, (add c1, c2)) @@ -7691,9 +7977,90 @@ SDValue DAGCombiner::visitSHL(SDNode *N) { if (SDValue NewSHL = visitShiftByConstant(N)) return NewSHL; + // Fold (shl (vscale * C0), C1) to (vscale * (C0 << C1)). 
+ if (N0.getOpcode() == ISD::VSCALE) + if (ConstantSDNode *NC1 = isConstOrConstSplat(N->getOperand(1))) { + auto DL = SDLoc(N); + APInt C0 = N0.getConstantOperandAPInt(0); + APInt C1 = NC1->getAPIntValue(); + return DAG.getVScale(DL, VT, C0 << C1); + } + return SDValue(); } +// Transform a right shift of a multiply into a multiply-high. +// Examples: +// (srl (mul (zext i32:$a to i64), (zext i32:$a to i64)), 32) -> (mulhu $a, $b) +// (sra (mul (sext i32:$a to i64), (sext i32:$a to i64)), 32) -> (mulhs $a, $b) +static SDValue combineShiftToMULH(SDNode *N, SelectionDAG &DAG, + const TargetLowering &TLI) { + assert((N->getOpcode() == ISD::SRL || N->getOpcode() == ISD::SRA) && + "SRL or SRA node is required here!"); + + // Check the shift amount. Proceed with the transformation if the shift + // amount is constant. + ConstantSDNode *ShiftAmtSrc = isConstOrConstSplat(N->getOperand(1)); + if (!ShiftAmtSrc) + return SDValue(); + + SDLoc DL(N); + + // The operation feeding into the shift must be a multiply. + SDValue ShiftOperand = N->getOperand(0); + if (ShiftOperand.getOpcode() != ISD::MUL) + return SDValue(); + + // Both operands must be equivalent extend nodes. + SDValue LeftOp = ShiftOperand.getOperand(0); + SDValue RightOp = ShiftOperand.getOperand(1); + bool IsSignExt = LeftOp.getOpcode() == ISD::SIGN_EXTEND; + bool IsZeroExt = LeftOp.getOpcode() == ISD::ZERO_EXTEND; + + if ((!(IsSignExt || IsZeroExt)) || LeftOp.getOpcode() != RightOp.getOpcode()) + return SDValue(); + + EVT WideVT1 = LeftOp.getValueType(); + EVT WideVT2 = RightOp.getValueType(); + (void)WideVT2; + // Proceed with the transformation if the wide types match. + assert((WideVT1 == WideVT2) && + "Cannot have a multiply node with two different operand types."); + + EVT NarrowVT = LeftOp.getOperand(0).getValueType(); + // Check that the two extend nodes are the same type. + if (NarrowVT != RightOp.getOperand(0).getValueType()) + return SDValue(); + + // Only transform into mulh if mulh for the narrow type is cheaper than + // a multiply followed by a shift. This should also check if mulh is + // legal for NarrowVT on the target. + if (!TLI.isMulhCheaperThanMulShift(NarrowVT)) + return SDValue(); + + // Proceed with the transformation if the wide type is twice as large + // as the narrow type. + unsigned NarrowVTSize = NarrowVT.getScalarSizeInBits(); + if (WideVT1.getScalarSizeInBits() != 2 * NarrowVTSize) + return SDValue(); + + // Check the shift amount with the narrow type size. + // Proceed with the transformation if the shift amount is the width + // of the narrow type. + unsigned ShiftAmt = ShiftAmtSrc->getZExtValue(); + if (ShiftAmt != NarrowVTSize) + return SDValue(); + + // If the operation feeding into the MUL is a sign extend (sext), + // we use mulhs. Othewise, zero extends (zext) use mulhu. + unsigned MulhOpcode = IsSignExt ? ISD::MULHS : ISD::MULHU; + + SDValue Result = DAG.getNode(MulhOpcode, DL, NarrowVT, LeftOp.getOperand(0), + RightOp.getOperand(0)); + return (N->getOpcode() == ISD::SRA ? DAG.getSExtOrTrunc(Result, DL, WideVT1) + : DAG.getZExtOrTrunc(Result, DL, WideVT1)); +} + SDValue DAGCombiner::visitSRA(SDNode *N) { SDValue N0 = N->getOperand(0); SDValue N1 = N->getOperand(1); @@ -7717,10 +8084,8 @@ SDValue DAGCombiner::visitSRA(SDNode *N) { ConstantSDNode *N1C = isConstOrConstSplat(N1); // fold (sra c1, c2) -> (sra c1, c2) - // TODO - support non-uniform vector shift amounts. 
- ConstantSDNode *N0C = getAsNonOpaqueConstant(N0); - if (N0C && N1C && !N1C->isOpaque()) - return DAG.FoldConstantArithmetic(ISD::SRA, SDLoc(N), VT, N0C, N1C); + if (SDValue C = DAG.FoldConstantArithmetic(ISD::SRA, SDLoc(N), VT, {N0, N1})) + return C; if (SDValue NewSel = foldBinOpIntoSelect(N)) return NewSel; @@ -7811,7 +8176,7 @@ SDValue DAGCombiner::visitSRA(SDNode *N) { // We convert trunc/ext to opposing shifts in IR, but casts may be cheaper. // sra (add (shl X, N1C), AddC), N1C --> // sext (add (trunc X to (width - N1C)), AddC') - if (!LegalTypes && N0.getOpcode() == ISD::ADD && N0.hasOneUse() && N1C && + if (N0.getOpcode() == ISD::ADD && N0.hasOneUse() && N1C && N0.getOperand(0).getOpcode() == ISD::SHL && N0.getOperand(0).getOperand(1) == N1 && N0.getOperand(0).hasOneUse()) { if (ConstantSDNode *AddC = isConstOrConstSplat(N0.getOperand(1))) { @@ -7828,7 +8193,8 @@ SDValue DAGCombiner::visitSRA(SDNode *N) { // implementation and/or target-specific overrides (because // non-simple types likely require masking when legalized), but that // restriction may conflict with other transforms. - if (TruncVT.isSimple() && TLI.isTruncateFree(VT, TruncVT)) { + if (TruncVT.isSimple() && isTypeLegal(TruncVT) && + TLI.isTruncateFree(VT, TruncVT)) { SDLoc DL(N); SDValue Trunc = DAG.getZExtOrTrunc(Shl.getOperand(0), DL, TruncVT); SDValue ShiftC = DAG.getConstant(AddC->getAPIntValue().lshr(ShiftAmt). @@ -7871,8 +8237,7 @@ SDValue DAGCombiner::visitSRA(SDNode *N) { } // Simplify, based on bits shifted out of the LHS. - // TODO - support non-uniform vector shift amounts. - if (N1C && SimplifyDemandedBits(SDValue(N, 0))) + if (SimplifyDemandedBits(SDValue(N, 0))) return SDValue(N, 0); // If the sign bit is known to be zero, switch this to a SRL. @@ -7883,6 +8248,11 @@ SDValue DAGCombiner::visitSRA(SDNode *N) { if (SDValue NewSRA = visitShiftByConstant(N)) return NewSRA; + // Try to transform this shift into a multiply-high if + // it matches the appropriate pattern detected in combineShiftToMULH. + if (SDValue MULH = combineShiftToMULH(N, DAG, TLI)) + return MULH; + return SDValue(); } @@ -7903,10 +8273,8 @@ SDValue DAGCombiner::visitSRL(SDNode *N) { ConstantSDNode *N1C = isConstOrConstSplat(N1); // fold (srl c1, c2) -> c1 >>u c2 - // TODO - support non-uniform vector shift amounts. - ConstantSDNode *N0C = getAsNonOpaqueConstant(N0); - if (N0C && N1C && !N1C->isOpaque()) - return DAG.FoldConstantArithmetic(ISD::SRL, SDLoc(N), VT, N0C, N1C); + if (SDValue C = DAG.FoldConstantArithmetic(ISD::SRL, SDLoc(N), VT, {N0, N1})) + return C; if (SDValue NewSel = foldBinOpIntoSelect(N)) return NewSel; @@ -8070,8 +8438,7 @@ SDValue DAGCombiner::visitSRL(SDNode *N) { // fold operands of srl based on knowledge that the low bits are not // demanded. - // TODO - support non-uniform vector shift amounts. - if (N1C && SimplifyDemandedBits(SDValue(N, 0))) + if (SimplifyDemandedBits(SDValue(N, 0))) return SDValue(N, 0); if (N1C && !N1C->isOpaque()) @@ -8111,6 +8478,11 @@ SDValue DAGCombiner::visitSRL(SDNode *N) { } } + // Try to transform this shift into a multiply-high if + // it matches the appropriate pattern detected in combineShiftToMULH. + if (SDValue MULH = combineShiftToMULH(N, DAG, TLI)) + return MULH; + return SDValue(); } @@ -8160,6 +8532,45 @@ SDValue DAGCombiner::visitFunnelShift(SDNode *N) { return DAG.getNode(ISD::SHL, SDLoc(N), VT, N0, DAG.getConstant(IsFSHL ? ShAmt : BitWidth - ShAmt, SDLoc(N), ShAmtTy)); + + // fold (fshl ld1, ld0, c) -> (ld0[ofs]) iff ld0 and ld1 are consecutive. 
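// Illustrative sketch (not part of the patch): the sra (add (shl X, C), AddC), C
// narrowing whose legality guard changes above is sound because shl by C zeroes
// the low C bits, so adding AddC cannot carry into bit C; the arithmetic shift
// then sees (X + (AddC >> C)) in the narrow (W - C)-bit type, sign-extended.
#include <cassert>
#include <cstdint>

int main() {
  const unsigned C = 24;                 // narrow type is i8
  int32_t X = 0x5ABC;
  int32_t AddC = 0x7F123456;
  // Original: sra (add (shl X, C), AddC), C
  int32_t Orig = (int32_t)(((uint32_t)X << C) + (uint32_t)AddC) >> C;
  // Folded:   sext (add (trunc X to i8), (AddC >> C) truncated to i8)
  int8_t Narrow = (int8_t)((int8_t)X + (int8_t)(AddC >> C));
  assert(Orig == (int32_t)Narrow);
  return 0;
}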
+ // fold (fshr ld1, ld0, c) -> (ld0[ofs]) iff ld0 and ld1 are consecutive. + // TODO - bigendian support once we have test coverage. + // TODO - can we merge this with CombineConseutiveLoads/MatchLoadCombine? + // TODO - permit LHS EXTLOAD if extensions are shifted out. + if ((BitWidth % 8) == 0 && (ShAmt % 8) == 0 && !VT.isVector() && + !DAG.getDataLayout().isBigEndian()) { + auto *LHS = dyn_cast<LoadSDNode>(N0); + auto *RHS = dyn_cast<LoadSDNode>(N1); + if (LHS && RHS && LHS->isSimple() && RHS->isSimple() && + LHS->getAddressSpace() == RHS->getAddressSpace() && + (LHS->hasOneUse() || RHS->hasOneUse()) && ISD::isNON_EXTLoad(RHS) && + ISD::isNON_EXTLoad(LHS)) { + if (DAG.areNonVolatileConsecutiveLoads(LHS, RHS, BitWidth / 8, 1)) { + SDLoc DL(RHS); + uint64_t PtrOff = + IsFSHL ? (((BitWidth - ShAmt) % BitWidth) / 8) : (ShAmt / 8); + Align NewAlign = commonAlignment(RHS->getAlign(), PtrOff); + bool Fast = false; + if (TLI.allowsMemoryAccess(*DAG.getContext(), DAG.getDataLayout(), VT, + RHS->getAddressSpace(), NewAlign, + RHS->getMemOperand()->getFlags(), &Fast) && + Fast) { + SDValue NewPtr = + DAG.getMemBasePlusOffset(RHS->getBasePtr(), PtrOff, DL); + AddToWorklist(NewPtr.getNode()); + SDValue Load = DAG.getLoad( + VT, DL, RHS->getChain(), NewPtr, + RHS->getPointerInfo().getWithOffset(PtrOff), NewAlign, + RHS->getMemOperand()->getFlags(), RHS->getAAInfo()); + // Replace the old load's chain with the new load's chain. + WorklistRemover DeadNodes(*this); + DAG.ReplaceAllUsesOfValueWith(N1.getValue(1), Load.getValue(1)); + return Load; + } + } + } + } } // fold fshr(undef_or_zero, N1, N2) -> lshr(N1, N2) @@ -8609,7 +9020,7 @@ SDValue DAGCombiner::visitSELECT(SDNode *N) { // Create the actual or node if we can generate good code for it. if (!normalizeToSequence) { SDValue Or = DAG.getNode(ISD::OR, DL, N0.getValueType(), N0, N2_0); - return DAG.getNode(ISD::SELECT, DL, N1.getValueType(), Or, N1, + return DAG.getNode(ISD::SELECT, DL, N1.getValueType(), Or, N1, N2_2, Flags); } // Otherwise see if we can optimize to a better pattern. 
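// Illustrative sketch (not part of the patch, little-endian host assumed to
// match the !isBigEndian() guard): funnel-shifting two consecutive 32-bit
// loads by a multiple of 8 selects a byte-aligned window of the underlying 8
// bytes, so the whole expression is one 32-bit load at offset (32 - c) / 8.
#include <cassert>
#include <cstdint>
#include <cstring>

static uint32_t fshl32(uint32_t hi, uint32_t lo, unsigned c) {
  return (hi << c) | (lo >> (32 - c));            // c in [1, 31]
}

int main() {
  unsigned char mem[8] = {0x10, 0x32, 0x54, 0x76, 0x98, 0xBA, 0xDC, 0xFE};
  uint32_t lo, hi, folded;
  std::memcpy(&lo, mem, 4);                       // ld0 at p
  std::memcpy(&hi, mem + 4, 4);                   // ld1 at p + 4
  std::memcpy(&folded, mem + 3, 4);               // single load at p + 3
  assert(fshl32(hi, lo, 8) == folded);            // fshl ld1, ld0, 8 --> ld0[3]
  return 0;
}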
@@ -8825,6 +9236,8 @@ SDValue DAGCombiner::foldVSelectOfConstants(SDNode *N) { SDValue N2Elt = N2.getOperand(i); if (N1Elt.isUndef() || N2Elt.isUndef()) continue; + if (N1Elt.getValueType() != N2Elt.getValueType()) + continue; const APInt &C1 = cast<ConstantSDNode>(N1Elt)->getAPIntValue(); const APInt &C2 = cast<ConstantSDNode>(N2Elt)->getAPIntValue(); @@ -9395,8 +9808,7 @@ SDValue DAGCombiner::CombineZExtLogicopShiftLoad(SDNode *N) { SDValue Shift = DAG.getNode(N1.getOpcode(), DL1, VT, ExtLoad, N1.getOperand(1)); - APInt Mask = cast<ConstantSDNode>(N0.getOperand(1))->getAPIntValue(); - Mask = Mask.zext(VT.getSizeInBits()); + APInt Mask = N0.getConstantOperandAPInt(1).zext(VT.getSizeInBits()); SDLoc DL0(N0); SDValue And = DAG.getNode(N0.getOpcode(), DL0, VT, Shift, DAG.getConstant(Mask, DL0, VT)); @@ -9702,8 +10114,7 @@ SDValue DAGCombiner::visitSIGN_EXTEND(SDNode *N) { LN00->getChain(), LN00->getBasePtr(), LN00->getMemoryVT(), LN00->getMemOperand()); - APInt Mask = cast<ConstantSDNode>(N0.getOperand(1))->getAPIntValue(); - Mask = Mask.sext(VT.getSizeInBits()); + APInt Mask = N0.getConstantOperandAPInt(1).sext(VT.getSizeInBits()); SDValue And = DAG.getNode(N0.getOpcode(), DL, VT, ExtLoad, DAG.getConstant(Mask, DL, VT)); ExtendSetCCUses(SetCCs, N0.getOperand(0), ExtLoad, ISD::SIGN_EXTEND); @@ -9941,7 +10352,7 @@ SDValue DAGCombiner::visitZERO_EXTEND(SDNode *N) { if (!LegalOperations || (TLI.isOperationLegal(ISD::AND, SrcVT) && TLI.isOperationLegal(ISD::ZERO_EXTEND, VT))) { SDValue Op = N0.getOperand(0); - Op = DAG.getZeroExtendInReg(Op, SDLoc(N), MinVT.getScalarType()); + Op = DAG.getZeroExtendInReg(Op, SDLoc(N), MinVT); AddToWorklist(Op.getNode()); SDValue ZExtOrTrunc = DAG.getZExtOrTrunc(Op, SDLoc(N), VT); // Transfer the debug info; the new node is equivalent to N0. @@ -9953,7 +10364,7 @@ SDValue DAGCombiner::visitZERO_EXTEND(SDNode *N) { if (!LegalOperations || TLI.isOperationLegal(ISD::AND, VT)) { SDValue Op = DAG.getAnyExtOrTrunc(N0.getOperand(0), SDLoc(N), VT); AddToWorklist(Op.getNode()); - SDValue And = DAG.getZeroExtendInReg(Op, SDLoc(N), MinVT.getScalarType()); + SDValue And = DAG.getZeroExtendInReg(Op, SDLoc(N), MinVT); // We may safely transfer the debug info describing the truncate node over // to the equivalent and operation. DAG.transferDbgValues(N0, And); @@ -9971,8 +10382,7 @@ SDValue DAGCombiner::visitZERO_EXTEND(SDNode *N) { !TLI.isZExtFree(N0.getValueType(), VT))) { SDValue X = N0.getOperand(0).getOperand(0); X = DAG.getAnyExtOrTrunc(X, SDLoc(X), VT); - APInt Mask = cast<ConstantSDNode>(N0.getOperand(1))->getAPIntValue(); - Mask = Mask.zext(VT.getSizeInBits()); + APInt Mask = N0.getConstantOperandAPInt(1).zext(VT.getSizeInBits()); SDLoc DL(N); return DAG.getNode(ISD::AND, DL, VT, X, DAG.getConstant(Mask, DL, VT)); @@ -10026,8 +10436,7 @@ SDValue DAGCombiner::visitZERO_EXTEND(SDNode *N) { LN00->getChain(), LN00->getBasePtr(), LN00->getMemoryVT(), LN00->getMemOperand()); - APInt Mask = cast<ConstantSDNode>(N0.getOperand(1))->getAPIntValue(); - Mask = Mask.zext(VT.getSizeInBits()); + APInt Mask = N0.getConstantOperandAPInt(1).zext(VT.getSizeInBits()); SDLoc DL(N); SDValue And = DAG.getNode(N0.getOpcode(), DL, VT, ExtLoad, DAG.getConstant(Mask, DL, VT)); @@ -10080,23 +10489,22 @@ SDValue DAGCombiner::visitZERO_EXTEND(SDNode *N) { // that the element size of the sext'd result matches the element size of // the compare operands. 
SDLoc DL(N); - SDValue VecOnes = DAG.getConstant(1, DL, VT); if (VT.getSizeInBits() == N00VT.getSizeInBits()) { - // zext(setcc) -> (and (vsetcc), (1, 1, ...) for vectors. + // zext(setcc) -> zext_in_reg(vsetcc) for vectors. SDValue VSetCC = DAG.getNode(ISD::SETCC, DL, VT, N0.getOperand(0), N0.getOperand(1), N0.getOperand(2)); - return DAG.getNode(ISD::AND, DL, VT, VSetCC, VecOnes); + return DAG.getZeroExtendInReg(VSetCC, DL, N0.getValueType()); } // If the desired elements are smaller or larger than the source // elements we can use a matching integer vector type and then - // truncate/sign extend. + // truncate/any extend followed by zext_in_reg. EVT MatchingVectorType = N00VT.changeVectorElementTypeToInteger(); SDValue VsetCC = DAG.getNode(ISD::SETCC, DL, MatchingVectorType, N0.getOperand(0), N0.getOperand(1), N0.getOperand(2)); - return DAG.getNode(ISD::AND, DL, VT, DAG.getSExtOrTrunc(VsetCC, DL, VT), - VecOnes); + return DAG.getZeroExtendInReg(DAG.getAnyExtOrTrunc(VsetCC, DL, VT), DL, + N0.getValueType()); } // zext(setcc x,y,cc) -> select_cc x, y, 1, 0, cc @@ -10127,7 +10535,7 @@ SDValue DAGCombiner::visitZERO_EXTEND(SDNode *N) { SDLoc DL(N); // Ensure that the shift amount is wide enough for the shifted value. - if (VT.getSizeInBits() >= 256) + if (Log2_32_Ceil(VT.getSizeInBits()) > ShAmt.getValueSizeInBits()) ShAmt = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::i32, ShAmt); return DAG.getNode(N0.getOpcode(), DL, VT, @@ -10187,8 +10595,7 @@ SDValue DAGCombiner::visitANY_EXTEND(SDNode *N) { SDLoc DL(N); SDValue X = N0.getOperand(0).getOperand(0); X = DAG.getAnyExtOrTrunc(X, DL, VT); - APInt Mask = cast<ConstantSDNode>(N0.getOperand(1))->getAPIntValue(); - Mask = Mask.zext(VT.getSizeInBits()); + APInt Mask = N0.getConstantOperandAPInt(1).zext(VT.getSizeInBits()); return DAG.getNode(ISD::AND, DL, VT, X, DAG.getConstant(Mask, DL, VT)); } @@ -10348,6 +10755,45 @@ SDValue DAGCombiner::visitAssertExt(SDNode *N) { return SDValue(); } +SDValue DAGCombiner::visitAssertAlign(SDNode *N) { + SDLoc DL(N); + + Align AL = cast<AssertAlignSDNode>(N)->getAlign(); + SDValue N0 = N->getOperand(0); + + // Fold (assertalign (assertalign x, AL0), AL1) -> + // (assertalign x, max(AL0, AL1)) + if (auto *AAN = dyn_cast<AssertAlignSDNode>(N0)) + return DAG.getAssertAlign(DL, N0.getOperand(0), + std::max(AL, AAN->getAlign())); + + // In rare cases, there are trivial arithmetic ops in source operands. Sink + // this assert down to source operands so that those arithmetic ops could be + // exposed to the DAG combining. 
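// Illustrative sketch (not part of the patch): the new Log2_32_Ceil guard
// widens the shift amount only when its integer type is too narrow to encode
// every in-range shift for VT, i.e. when ceil(log2(VT bits)) exceeds the
// amount type's width.
#include <cassert>
#include <cstdint>

static unsigned log2_ceil(uint32_t v) {          // stand-in for Log2_32_Ceil
  unsigned Bits = 0;
  while ((1u << Bits) < v)
    ++Bits;
  return Bits;
}

int main() {
  // A 512-bit value needs shift amounts up to 511, which do not fit in i8.
  assert(log2_ceil(512) > 8);
  // A 256-bit value's amounts (up to 255) still fit in i8: no widening needed.
  assert(!(log2_ceil(256) > 8));
  return 0;
}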
+ switch (N0.getOpcode()) { + default: + break; + case ISD::ADD: + case ISD::SUB: { + unsigned AlignShift = Log2(AL); + SDValue LHS = N0.getOperand(0); + SDValue RHS = N0.getOperand(1); + unsigned LHSAlignShift = DAG.computeKnownBits(LHS).countMinTrailingZeros(); + unsigned RHSAlignShift = DAG.computeKnownBits(RHS).countMinTrailingZeros(); + if (LHSAlignShift >= AlignShift || RHSAlignShift >= AlignShift) { + if (LHSAlignShift < AlignShift) + LHS = DAG.getAssertAlign(DL, LHS, AL); + if (RHSAlignShift < AlignShift) + RHS = DAG.getAssertAlign(DL, RHS, AL); + return DAG.getNode(N0.getOpcode(), DL, N0.getValueType(), LHS, RHS); + } + break; + } + } + + return SDValue(); +} + /// If the result of a wider load is shifted to right of N bits and then /// truncated to a narrower type and where N is a multiple of number of bits of /// the narrower type, transform it to a narrower load from address + N / num of @@ -10428,9 +10874,8 @@ SDValue DAGCombiner::ReduceLoadWidth(SDNode *N) { } // At this point, we must have a load or else we can't do the transform. - if (!isa<LoadSDNode>(N0)) return SDValue(); - - auto *LN0 = cast<LoadSDNode>(N0); + auto *LN0 = dyn_cast<LoadSDNode>(N0); + if (!LN0) return SDValue(); // Because a SRL must be assumed to *need* to zero-extend the high bits // (as opposed to anyext the high bits), we can't combine the zextload @@ -10449,8 +10894,7 @@ SDValue DAGCombiner::ReduceLoadWidth(SDNode *N) { SDNode *Mask = *(SRL->use_begin()); if (Mask->getOpcode() == ISD::AND && isa<ConstantSDNode>(Mask->getOperand(1))) { - const APInt &ShiftMask = - cast<ConstantSDNode>(Mask->getOperand(1))->getAPIntValue(); + const APInt& ShiftMask = Mask->getConstantOperandAPInt(1); if (ShiftMask.isMask()) { EVT MaskedVT = EVT::getIntegerVT(*DAG.getContext(), ShiftMask.countTrailingOnes()); @@ -10480,7 +10924,7 @@ SDValue DAGCombiner::ReduceLoadWidth(SDNode *N) { LoadSDNode *LN0 = cast<LoadSDNode>(N0); // Reducing the width of a volatile load is illegal. For atomics, we may be - // able to reduce the width provided we never widen again. (see D66309) + // able to reduce the width provided we never widen again. (see D66309) if (!LN0->isSimple() || !isLegalNarrowLdSt(LN0, ExtType, ExtVT, ShAmt)) return SDValue(); @@ -10561,26 +11005,27 @@ SDValue DAGCombiner::visitSIGN_EXTEND_INREG(SDNode *N) { SDValue N0 = N->getOperand(0); SDValue N1 = N->getOperand(1); EVT VT = N->getValueType(0); - EVT EVT = cast<VTSDNode>(N1)->getVT(); + EVT ExtVT = cast<VTSDNode>(N1)->getVT(); unsigned VTBits = VT.getScalarSizeInBits(); - unsigned EVTBits = EVT.getScalarSizeInBits(); + unsigned ExtVTBits = ExtVT.getScalarSizeInBits(); + // sext_vector_inreg(undef) = 0 because the top bit will all be the same. if (N0.isUndef()) - return DAG.getUNDEF(VT); + return DAG.getConstant(0, SDLoc(N), VT); // fold (sext_in_reg c1) -> c1 if (DAG.isConstantIntBuildVectorOrConstantInt(N0)) return DAG.getNode(ISD::SIGN_EXTEND_INREG, SDLoc(N), VT, N0, N1); // If the input is already sign extended, just drop the extension. 
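// Illustrative sketch (not part of the patch): sinking the alignment assert
// through ADD/SUB is justified because adding an operand whose low
// log2(Align) bits are known zero cannot change alignment, so the assert can
// be attached to the other operand instead.
#include <bit>
#include <cassert>
#include <cstdint>

int main() {
  const unsigned AlignShift = 4;                           // assertalign 16
  uint64_t Base = 0x1000;                                  // known 16-byte aligned
  uint64_t Off = 0x230;                                    // operand that inherits the assert
  assert((unsigned)std::countr_zero(Base) >= AlignShift);  // countMinTrailingZeros >= AlignShift
  // Base + Off is 16-byte aligned exactly when Off is.
  assert((((Base + Off) & 15) == 0) == ((Off & 15) == 0));
  return 0;
}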
- if (DAG.ComputeNumSignBits(N0) >= VTBits-EVTBits+1) + if (DAG.ComputeNumSignBits(N0) >= (VTBits - ExtVTBits + 1)) return N0; // fold (sext_in_reg (sext_in_reg x, VT2), VT1) -> (sext_in_reg x, minVT) pt2 if (N0.getOpcode() == ISD::SIGN_EXTEND_INREG && - EVT.bitsLT(cast<VTSDNode>(N0.getOperand(1))->getVT())) - return DAG.getNode(ISD::SIGN_EXTEND_INREG, SDLoc(N), VT, - N0.getOperand(0), N1); + ExtVT.bitsLT(cast<VTSDNode>(N0.getOperand(1))->getVT())) + return DAG.getNode(ISD::SIGN_EXTEND_INREG, SDLoc(N), VT, N0.getOperand(0), + N1); // fold (sext_in_reg (sext x)) -> (sext x) // fold (sext_in_reg (aext x)) -> (sext x) @@ -10589,8 +11034,8 @@ SDValue DAGCombiner::visitSIGN_EXTEND_INREG(SDNode *N) { if (N0.getOpcode() == ISD::SIGN_EXTEND || N0.getOpcode() == ISD::ANY_EXTEND) { SDValue N00 = N0.getOperand(0); unsigned N00Bits = N00.getScalarValueSizeInBits(); - if ((N00Bits <= EVTBits || - (N00Bits - DAG.ComputeNumSignBits(N00)) < EVTBits) && + if ((N00Bits <= ExtVTBits || + (N00Bits - DAG.ComputeNumSignBits(N00)) < ExtVTBits) && (!LegalOperations || TLI.isOperationLegal(ISD::SIGN_EXTEND, VT))) return DAG.getNode(ISD::SIGN_EXTEND, SDLoc(N), VT, N00); } @@ -10599,7 +11044,7 @@ SDValue DAGCombiner::visitSIGN_EXTEND_INREG(SDNode *N) { if ((N0.getOpcode() == ISD::ANY_EXTEND_VECTOR_INREG || N0.getOpcode() == ISD::SIGN_EXTEND_VECTOR_INREG || N0.getOpcode() == ISD::ZERO_EXTEND_VECTOR_INREG) && - N0.getOperand(0).getScalarValueSizeInBits() == EVTBits) { + N0.getOperand(0).getScalarValueSizeInBits() == ExtVTBits) { if (!LegalOperations || TLI.isOperationLegal(ISD::SIGN_EXTEND_VECTOR_INREG, VT)) return DAG.getNode(ISD::SIGN_EXTEND_VECTOR_INREG, SDLoc(N), VT, @@ -10610,14 +11055,14 @@ SDValue DAGCombiner::visitSIGN_EXTEND_INREG(SDNode *N) { // iff we are extending the source sign bit. if (N0.getOpcode() == ISD::ZERO_EXTEND) { SDValue N00 = N0.getOperand(0); - if (N00.getScalarValueSizeInBits() == EVTBits && + if (N00.getScalarValueSizeInBits() == ExtVTBits && (!LegalOperations || TLI.isOperationLegal(ISD::SIGN_EXTEND, VT))) return DAG.getNode(ISD::SIGN_EXTEND, SDLoc(N), VT, N00, N1); } // fold (sext_in_reg x) -> (zext_in_reg x) if the sign bit is known zero. - if (DAG.MaskedValueIsZero(N0, APInt::getOneBitSet(VTBits, EVTBits - 1))) - return DAG.getZeroExtendInReg(N0, SDLoc(N), EVT.getScalarType()); + if (DAG.MaskedValueIsZero(N0, APInt::getOneBitSet(VTBits, ExtVTBits - 1))) + return DAG.getZeroExtendInReg(N0, SDLoc(N), ExtVT); // fold operands of sext_in_reg based on knowledge that the top bits are not // demanded. @@ -10634,11 +11079,11 @@ SDValue DAGCombiner::visitSIGN_EXTEND_INREG(SDNode *N) { // We already fold "(sext_in_reg (srl X, 25), i8) -> srl X, 25" above. if (N0.getOpcode() == ISD::SRL) { if (auto *ShAmt = dyn_cast<ConstantSDNode>(N0.getOperand(1))) - if (ShAmt->getAPIntValue().ule(VTBits - EVTBits)) { + if (ShAmt->getAPIntValue().ule(VTBits - ExtVTBits)) { // We can turn this into an SRA iff the input to the SRL is already sign // extended enough. unsigned InSignBits = DAG.ComputeNumSignBits(N0.getOperand(0)); - if (((VTBits - EVTBits) - ShAmt->getZExtValue()) < InSignBits) + if (((VTBits - ExtVTBits) - ShAmt->getZExtValue()) < InSignBits) return DAG.getNode(ISD::SRA, SDLoc(N), VT, N0.getOperand(0), N0.getOperand(1)); } @@ -10650,14 +11095,14 @@ SDValue DAGCombiner::visitSIGN_EXTEND_INREG(SDNode *N) { // extends that the target does support. 
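// Illustrative sketch (not part of the patch, assuming arithmetic right shift
// of signed values): sign_extend_inreg from width B is the shift pair
// (x << (W - B)) >>s (W - B); when the input already has at least W - B + 1
// identical top bits it is the identity, which is the ComputeNumSignBits
// early-out kept above.
#include <cassert>
#include <cstdint>

static int32_t sext_in_reg(int32_t x, unsigned B) {
  unsigned Sh = 32 - B;
  return (int32_t)((uint32_t)x << Sh) >> Sh;
}

int main() {
  assert(sext_in_reg(0x000000FF, 8) == -1);       // low byte 0xFF -> -1
  int32_t Already = -5;                           // well over the 25 sign bits required
  assert(sext_in_reg(Already, 8) == Already);     // extension already done: no-op
  return 0;
}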
if (ISD::isEXTLoad(N0.getNode()) && ISD::isUNINDEXEDLoad(N0.getNode()) && - EVT == cast<LoadSDNode>(N0)->getMemoryVT() && + ExtVT == cast<LoadSDNode>(N0)->getMemoryVT() && ((!LegalOperations && cast<LoadSDNode>(N0)->isSimple() && N0.hasOneUse()) || - TLI.isLoadExtLegal(ISD::SEXTLOAD, VT, EVT))) { + TLI.isLoadExtLegal(ISD::SEXTLOAD, VT, ExtVT))) { LoadSDNode *LN0 = cast<LoadSDNode>(N0); SDValue ExtLoad = DAG.getExtLoad(ISD::SEXTLOAD, SDLoc(N), VT, LN0->getChain(), - LN0->getBasePtr(), EVT, + LN0->getBasePtr(), ExtVT, LN0->getMemOperand()); CombineTo(N, ExtLoad); CombineTo(N0.getNode(), ExtLoad, ExtLoad.getValue(1)); @@ -10667,13 +11112,13 @@ SDValue DAGCombiner::visitSIGN_EXTEND_INREG(SDNode *N) { // fold (sext_inreg (zextload x)) -> (sextload x) iff load has one use if (ISD::isZEXTLoad(N0.getNode()) && ISD::isUNINDEXEDLoad(N0.getNode()) && N0.hasOneUse() && - EVT == cast<LoadSDNode>(N0)->getMemoryVT() && + ExtVT == cast<LoadSDNode>(N0)->getMemoryVT() && ((!LegalOperations && cast<LoadSDNode>(N0)->isSimple()) && - TLI.isLoadExtLegal(ISD::SEXTLOAD, VT, EVT))) { + TLI.isLoadExtLegal(ISD::SEXTLOAD, VT, ExtVT))) { LoadSDNode *LN0 = cast<LoadSDNode>(N0); SDValue ExtLoad = DAG.getExtLoad(ISD::SEXTLOAD, SDLoc(N), VT, LN0->getChain(), - LN0->getBasePtr(), EVT, + LN0->getBasePtr(), ExtVT, LN0->getMemOperand()); CombineTo(N, ExtLoad); CombineTo(N0.getNode(), ExtLoad, ExtLoad.getValue(1)); @@ -10681,11 +11126,10 @@ SDValue DAGCombiner::visitSIGN_EXTEND_INREG(SDNode *N) { } // Form (sext_inreg (bswap >> 16)) or (sext_inreg (rotl (bswap) 16)) - if (EVTBits <= 16 && N0.getOpcode() == ISD::OR) { + if (ExtVTBits <= 16 && N0.getOpcode() == ISD::OR) { if (SDValue BSwap = MatchBSwapHWordLow(N0.getNode(), N0.getOperand(0), N0.getOperand(1), false)) - return DAG.getNode(ISD::SIGN_EXTEND_INREG, SDLoc(N), VT, - BSwap, N1); + return DAG.getNode(ISD::SIGN_EXTEND_INREG, SDLoc(N), VT, BSwap, N1); } return SDValue(); @@ -10695,8 +11139,9 @@ SDValue DAGCombiner::visitSIGN_EXTEND_VECTOR_INREG(SDNode *N) { SDValue N0 = N->getOperand(0); EVT VT = N->getValueType(0); + // sext_vector_inreg(undef) = 0 because the top bit will all be the same. if (N0.isUndef()) - return DAG.getUNDEF(VT); + return DAG.getConstant(0, SDLoc(N), VT); if (SDValue Res = tryToFoldExtendOfConstant(N, TLI, DAG, LegalTypes)) return Res; @@ -10711,8 +11156,9 @@ SDValue DAGCombiner::visitZERO_EXTEND_VECTOR_INREG(SDNode *N) { SDValue N0 = N->getOperand(0); EVT VT = N->getValueType(0); + // zext_vector_inreg(undef) = 0 because the top bits will be zero. if (N0.isUndef()) - return DAG.getUNDEF(VT); + return DAG.getConstant(0, SDLoc(N), VT); if (SDValue Res = tryToFoldExtendOfConstant(N, TLI, DAG, LegalTypes)) return Res; @@ -10788,13 +11234,12 @@ SDValue DAGCombiner::visitTRUNCATE(SDNode *N) { SDValue EltNo = N0->getOperand(1); if (isa<ConstantSDNode>(EltNo) && isTypeLegal(NVT)) { int Elt = cast<ConstantSDNode>(EltNo)->getZExtValue(); - EVT IndexTy = TLI.getVectorIdxTy(DAG.getDataLayout()); int Index = isLE ? (Elt*SizeRatio) : (Elt*SizeRatio + (SizeRatio-1)); SDLoc DL(N); return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, TrTy, DAG.getBitcast(NVT, N0.getOperand(0)), - DAG.getConstant(Index, DL, IndexTy)); + DAG.getVectorIdxConstant(Index, DL)); } } @@ -10832,7 +11277,9 @@ SDValue DAGCombiner::visitTRUNCATE(SDNode *N) { // Attempt to pre-truncate BUILD_VECTOR sources. 
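// Illustrative sketch (not part of the patch, little-endian host assumed to
// match the isLE index choice): truncating lane Elt of a <N x i32> vector to
// i16 reads the same bytes as lane Elt * SizeRatio of the bitcast <2N x i16>
// vector, which is the index scaling used by the truncate-of-extract fold.
#include <cassert>
#include <cstdint>
#include <cstring>

int main() {
  uint32_t V32[4] = {0x11112222, 0x33334444, 0x55556666, 0x77778888};
  uint16_t V16[8];
  std::memcpy(V16, V32, sizeof(V32));            // the "bitcast" to <8 x i16>
  const int Elt = 2, SizeRatio = 2;              // i32 -> i16
  assert((uint16_t)V32[Elt] == V16[Elt * SizeRatio]);
  return 0;
}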
if (N0.getOpcode() == ISD::BUILD_VECTOR && !LegalOperations && - TLI.isTruncateFree(SrcVT.getScalarType(), VT.getScalarType())) { + TLI.isTruncateFree(SrcVT.getScalarType(), VT.getScalarType()) && + // Avoid creating illegal types if running after type legalizer. + (!LegalTypes || TLI.isTypeLegal(VT.getScalarType()))) { SDLoc DL(N); EVT SVT = VT.getScalarType(); SmallVector<SDValue, 8> TruncOps; @@ -10961,10 +11408,9 @@ SDValue DAGCombiner::visitTRUNCATE(SDNode *N) { TLI.isOperationLegal(ISD::EXTRACT_VECTOR_ELT, VecSrcVT))) { SDLoc SL(N); - EVT IdxVT = TLI.getVectorIdxTy(DAG.getDataLayout()); unsigned Idx = isLE ? 0 : VecSrcVT.getVectorNumElements() - 1; return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, SL, VT, VecSrc, - DAG.getConstant(Idx, SL, IdxVT)); + DAG.getVectorIdxConstant(Idx, SL)); } } @@ -11064,14 +11510,14 @@ SDValue DAGCombiner::CombineConsecutiveLoads(SDNode *N, EVT VT) { unsigned LD1Bytes = LD1VT.getStoreSize(); if (ISD::isNON_EXTLoad(LD2) && LD2->hasOneUse() && DAG.areNonVolatileConsecutiveLoads(LD2, LD1, LD1Bytes, 1)) { - unsigned Align = LD1->getAlignment(); - unsigned NewAlign = DAG.getDataLayout().getABITypeAlignment( + Align Alignment = LD1->getAlign(); + Align NewAlign = DAG.getDataLayout().getABITypeAlign( VT.getTypeForEVT(*DAG.getContext())); - if (NewAlign <= Align && + if (NewAlign <= Alignment && (!LegalOperations || TLI.isOperationLegal(ISD::LOAD, VT))) return DAG.getLoad(VT, SDLoc(N), LD1->getChain(), LD1->getBasePtr(), - LD1->getPointerInfo(), Align); + LD1->getPointerInfo(), Alignment); } return SDValue(); @@ -11389,6 +11835,20 @@ SDValue DAGCombiner::visitBUILD_PAIR(SDNode *N) { return CombineConsecutiveLoads(N, VT); } +SDValue DAGCombiner::visitFREEZE(SDNode *N) { + SDValue N0 = N->getOperand(0); + + // (freeze (freeze x)) -> (freeze x) + if (N0.getOpcode() == ISD::FREEZE) + return N0; + + // If the input is a constant, return it. + if (isa<ConstantSDNode>(N0) || isa<ConstantFPSDNode>(N0)) + return N0; + + return SDValue(); +} + /// We know that BV is a build_vector node with Constant, ConstantFP or Undef /// operands. DstEltVT indicates the destination element value type. SDValue DAGCombiner:: @@ -11519,7 +11979,7 @@ SDValue DAGCombiner::visitFADDForFMACombine(SDNode *N) { const TargetOptions &Options = DAG.getTarget().Options; // Floating-point multiply-add with intermediate rounding. - bool HasFMAD = (LegalOperations && TLI.isFMADLegalForFAddFSub(DAG, N)); + bool HasFMAD = (LegalOperations && TLI.isFMADLegal(DAG, N)); // Floating-point multiply-add without intermediate rounding. bool HasFMA = @@ -11532,13 +11992,14 @@ SDValue DAGCombiner::visitFADDForFMACombine(SDNode *N) { SDNodeFlags Flags = N->getFlags(); bool CanFuse = Options.UnsafeFPMath || isContractable(N); + bool CanReassociate = + Options.UnsafeFPMath || N->getFlags().hasAllowReassociation(); bool AllowFusionGlobally = (Options.AllowFPOpFusion == FPOpFusion::Fast || CanFuse || HasFMAD); // If the addition is not contractable, do not combine. if (!AllowFusionGlobally && !isContractable(N)) return SDValue(); - const SelectionDAGTargetInfo *STI = DAG.getSubtarget().getSelectionDAGInfo(); if (STI && STI->generateFMAsInMachineCombiner(OptLevel)) return SDValue(); @@ -11573,6 +12034,30 @@ SDValue DAGCombiner::visitFADDForFMACombine(SDNode *N) { N1.getOperand(0), N1.getOperand(1), N0, Flags); } + // fadd (fma A, B, (fmul C, D)), E --> fma A, B, (fma C, D, E) + // fadd E, (fma A, B, (fmul C, D)) --> fma A, B, (fma C, D, E) + // This requires reassociation because it changes the order of operations. 
+ SDValue FMA, E; + if (CanReassociate && N0.getOpcode() == PreferredFusedOpcode && + N0.getOperand(2).getOpcode() == ISD::FMUL && N0.hasOneUse() && + N0.getOperand(2).hasOneUse()) { + FMA = N0; + E = N1; + } else if (CanReassociate && N1.getOpcode() == PreferredFusedOpcode && + N1.getOperand(2).getOpcode() == ISD::FMUL && N1.hasOneUse() && + N1.getOperand(2).hasOneUse()) { + FMA = N1; + E = N0; + } + if (FMA && E) { + SDValue A = FMA.getOperand(0); + SDValue B = FMA.getOperand(1); + SDValue C = FMA.getOperand(2).getOperand(0); + SDValue D = FMA.getOperand(2).getOperand(1); + SDValue CDE = DAG.getNode(PreferredFusedOpcode, SL, VT, C, D, E, Flags); + return DAG.getNode(PreferredFusedOpcode, SL, VT, A, B, CDE, Flags); + } + // Look through FP_EXTEND nodes to do more combining. // fold (fadd (fpext (fmul x, y)), z) -> (fma (fpext x), (fpext y), z) @@ -11606,33 +12091,6 @@ SDValue DAGCombiner::visitFADDForFMACombine(SDNode *N) { // More folding opportunities when target permits. if (Aggressive) { - // fold (fadd (fma x, y, (fmul u, v)), z) -> (fma x, y (fma u, v, z)) - if (CanFuse && - N0.getOpcode() == PreferredFusedOpcode && - N0.getOperand(2).getOpcode() == ISD::FMUL && - N0->hasOneUse() && N0.getOperand(2)->hasOneUse()) { - return DAG.getNode(PreferredFusedOpcode, SL, VT, - N0.getOperand(0), N0.getOperand(1), - DAG.getNode(PreferredFusedOpcode, SL, VT, - N0.getOperand(2).getOperand(0), - N0.getOperand(2).getOperand(1), - N1, Flags), Flags); - } - - // fold (fadd x, (fma y, z, (fmul u, v)) -> (fma y, z (fma u, v, x)) - if (CanFuse && - N1->getOpcode() == PreferredFusedOpcode && - N1.getOperand(2).getOpcode() == ISD::FMUL && - N1->hasOneUse() && N1.getOperand(2)->hasOneUse()) { - return DAG.getNode(PreferredFusedOpcode, SL, VT, - N1.getOperand(0), N1.getOperand(1), - DAG.getNode(PreferredFusedOpcode, SL, VT, - N1.getOperand(2).getOperand(0), - N1.getOperand(2).getOperand(1), - N0, Flags), Flags); - } - - // fold (fadd (fma x, y, (fpext (fmul u, v))), z) // -> (fma x, y, (fma (fpext u), (fpext v), z)) auto FoldFAddFMAFPExtFMul = [&] ( @@ -11736,7 +12194,7 @@ SDValue DAGCombiner::visitFSUBForFMACombine(SDNode *N) { const TargetOptions &Options = DAG.getTarget().Options; // Floating-point multiply-add with intermediate rounding. - bool HasFMAD = (LegalOperations && TLI.isFMADLegalForFAddFSub(DAG, N)); + bool HasFMAD = (LegalOperations && TLI.isFMADLegal(DAG, N)); // Floating-point multiply-add without intermediate rounding. bool HasFMA = @@ -11756,13 +12214,13 @@ SDValue DAGCombiner::visitFSUBForFMACombine(SDNode *N) { if (!AllowFusionGlobally && !isContractable(N)) return SDValue(); - const SelectionDAGTargetInfo *STI = DAG.getSubtarget().getSelectionDAGInfo(); if (STI && STI->generateFMAsInMachineCombiner(OptLevel)) return SDValue(); // Always prefer FMAD to FMA for precision. unsigned PreferredFusedOpcode = HasFMAD ? ISD::FMAD : ISD::FMA; bool Aggressive = TLI.enableAggressiveFMAFusion(VT); + bool NoSignedZero = Options.NoSignedZerosFPMath || Flags.hasNoSignedZeros(); // Is the node an FMUL and contractable either due to global flags or // SDNodeFlags. 
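// Illustrative sketch (not part of the patch): the new reassociation turns the
// trailing multiply plus the outer add into a second fused multiply-add. It
// changes the order of the additions, hence the reassoc/unsafe-fp guard; the
// values below agree exactly only because every intermediate is exactly
// representable.
#include <cassert>
#include <cmath>

int main() {
  double A = 3.0, B = 5.0, C = 7.0, D = 11.0, E = 13.0;
  double Before = std::fma(A, B, C * D) + E;         // fadd (fma A, B, (fmul C, D)), E
  double After = std::fma(A, B, std::fma(C, D, E));  // fma A, B, (fma C, D, E)
  assert(Before == After);                           // 15 + 77 + 13 either way
  return 0;
}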
@@ -11773,19 +12231,43 @@ SDValue DAGCombiner::visitFSUBForFMACombine(SDNode *N) { }; // fold (fsub (fmul x, y), z) -> (fma x, y, (fneg z)) - if (isContractableFMUL(N0) && (Aggressive || N0->hasOneUse())) { - return DAG.getNode(PreferredFusedOpcode, SL, VT, - N0.getOperand(0), N0.getOperand(1), - DAG.getNode(ISD::FNEG, SL, VT, N1), Flags); - } + auto tryToFoldXYSubZ = [&](SDValue XY, SDValue Z) { + if (isContractableFMUL(XY) && (Aggressive || XY->hasOneUse())) { + return DAG.getNode(PreferredFusedOpcode, SL, VT, XY.getOperand(0), + XY.getOperand(1), DAG.getNode(ISD::FNEG, SL, VT, Z), + Flags); + } + return SDValue(); + }; // fold (fsub x, (fmul y, z)) -> (fma (fneg y), z, x) // Note: Commutes FSUB operands. - if (isContractableFMUL(N1) && (Aggressive || N1->hasOneUse())) { - return DAG.getNode(PreferredFusedOpcode, SL, VT, - DAG.getNode(ISD::FNEG, SL, VT, - N1.getOperand(0)), - N1.getOperand(1), N0, Flags); + auto tryToFoldXSubYZ = [&](SDValue X, SDValue YZ) { + if (isContractableFMUL(YZ) && (Aggressive || YZ->hasOneUse())) { + return DAG.getNode(PreferredFusedOpcode, SL, VT, + DAG.getNode(ISD::FNEG, SL, VT, YZ.getOperand(0)), + YZ.getOperand(1), X, Flags); + } + return SDValue(); + }; + + // If we have two choices trying to fold (fsub (fmul u, v), (fmul x, y)), + // prefer to fold the multiply with fewer uses. + if (isContractableFMUL(N0) && isContractableFMUL(N1) && + (N0.getNode()->use_size() > N1.getNode()->use_size())) { + // fold (fsub (fmul a, b), (fmul c, d)) -> (fma (fneg c), d, (fmul a, b)) + if (SDValue V = tryToFoldXSubYZ(N0, N1)) + return V; + // fold (fsub (fmul a, b), (fmul c, d)) -> (fma a, b, (fneg (fmul c, d))) + if (SDValue V = tryToFoldXYSubZ(N0, N1)) + return V; + } else { + // fold (fsub (fmul x, y), z) -> (fma x, y, (fneg z)) + if (SDValue V = tryToFoldXYSubZ(N0, N1)) + return V; + // fold (fsub x, (fmul y, z)) -> (fma (fneg y), z, x) + if (SDValue V = tryToFoldXSubYZ(N0, N1)) + return V; } // fold (fsub (fneg (fmul, x, y)), z) -> (fma (fneg x), y, (fneg z)) @@ -11902,7 +12384,7 @@ SDValue DAGCombiner::visitFSUBForFMACombine(SDNode *N) { // -> (fma (fneg y), z, (fma (fneg u), v, x)) if (CanFuse && N1.getOpcode() == PreferredFusedOpcode && isContractableFMUL(N1.getOperand(2)) && - N1->hasOneUse()) { + N1->hasOneUse() && NoSignedZero) { SDValue N20 = N1.getOperand(2).getOperand(0); SDValue N21 = N1.getOperand(2).getOperand(1); return DAG.getNode(PreferredFusedOpcode, SL, VT, @@ -12055,7 +12537,7 @@ SDValue DAGCombiner::visitFMULForFMADistributiveCombine(SDNode *N) { // Floating-point multiply-add with intermediate rounding. This can result // in a less precise result due to the changed rounding order. bool HasFMAD = Options.UnsafeFPMath && - (LegalOperations && TLI.isOperationLegal(ISD::FMAD, VT)); + (LegalOperations && TLI.isFMADLegal(DAG, N)); // No valid opcode, do not combine. 
if (!HasFMAD && !HasFMA) @@ -12132,6 +12614,9 @@ SDValue DAGCombiner::visitFADD(SDNode *N) { const TargetOptions &Options = DAG.getTarget().Options; const SDNodeFlags Flags = N->getFlags(); + if (SDValue R = DAG.simplifyFPBinop(N->getOpcode(), N0, N1, Flags)) + return R; + // fold vector ops if (VT.isVector()) if (SDValue FoldedVOp = SimplifyVBinOp(N)) @@ -12155,18 +12640,16 @@ SDValue DAGCombiner::visitFADD(SDNode *N) { return NewSel; // fold (fadd A, (fneg B)) -> (fsub A, B) - if ((!LegalOperations || TLI.isOperationLegalOrCustom(ISD::FSUB, VT)) && - TLI.isNegatibleForFree(N1, DAG, LegalOperations, ForCodeSize) == 2) - return DAG.getNode( - ISD::FSUB, DL, VT, N0, - TLI.getNegatedExpression(N1, DAG, LegalOperations, ForCodeSize), Flags); + if (!LegalOperations || TLI.isOperationLegalOrCustom(ISD::FSUB, VT)) + if (SDValue NegN1 = TLI.getCheaperNegatedExpression( + N1, DAG, LegalOperations, ForCodeSize)) + return DAG.getNode(ISD::FSUB, DL, VT, N0, NegN1, Flags); // fold (fadd (fneg A), B) -> (fsub B, A) - if ((!LegalOperations || TLI.isOperationLegalOrCustom(ISD::FSUB, VT)) && - TLI.isNegatibleForFree(N0, DAG, LegalOperations, ForCodeSize) == 2) - return DAG.getNode( - ISD::FSUB, DL, VT, N1, - TLI.getNegatedExpression(N0, DAG, LegalOperations, ForCodeSize), Flags); + if (!LegalOperations || TLI.isOperationLegalOrCustom(ISD::FSUB, VT)) + if (SDValue NegN0 = TLI.getCheaperNegatedExpression( + N0, DAG, LegalOperations, ForCodeSize)) + return DAG.getNode(ISD::FSUB, DL, VT, N1, NegN0, Flags); auto isFMulNegTwo = [](SDValue FMul) { if (!FMul.hasOneUse() || FMul.getOpcode() != ISD::FMUL) @@ -12311,6 +12794,9 @@ SDValue DAGCombiner::visitFSUB(SDNode *N) { const TargetOptions &Options = DAG.getTarget().Options; const SDNodeFlags Flags = N->getFlags(); + if (SDValue R = DAG.simplifyFPBinop(N->getOpcode(), N0, N1, Flags)) + return R; + // fold vector ops if (VT.isVector()) if (SDValue FoldedVOp = SimplifyVBinOp(N)) @@ -12345,8 +12831,9 @@ SDValue DAGCombiner::visitFSUB(SDNode *N) { if (N0CFP && N0CFP->isZero()) { if (N0CFP->isNegative() || (Options.NoSignedZerosFPMath || Flags.hasNoSignedZeros())) { - if (TLI.isNegatibleForFree(N1, DAG, LegalOperations, ForCodeSize)) - return TLI.getNegatedExpression(N1, DAG, LegalOperations, ForCodeSize); + if (SDValue NegN1 = + TLI.getNegatedExpression(N1, DAG, LegalOperations, ForCodeSize)) + return NegN1; if (!LegalOperations || TLI.isOperationLegal(ISD::FNEG, VT)) return DAG.getNode(ISD::FNEG, DL, VT, N1, Flags); } @@ -12364,10 +12851,9 @@ SDValue DAGCombiner::visitFSUB(SDNode *N) { } // fold (fsub A, (fneg B)) -> (fadd A, B) - if (TLI.isNegatibleForFree(N1, DAG, LegalOperations, ForCodeSize)) - return DAG.getNode( - ISD::FADD, DL, VT, N0, - TLI.getNegatedExpression(N1, DAG, LegalOperations, ForCodeSize), Flags); + if (SDValue NegN1 = + TLI.getNegatedExpression(N1, DAG, LegalOperations, ForCodeSize)) + return DAG.getNode(ISD::FADD, DL, VT, N0, NegN1, Flags); // FSUB -> FMA combines: if (SDValue Fused = visitFSUBForFMACombine(N)) { @@ -12378,21 +12864,6 @@ SDValue DAGCombiner::visitFSUB(SDNode *N) { return SDValue(); } -/// Return true if both inputs are at least as cheap in negated form and at -/// least one input is strictly cheaper in negated form. -bool DAGCombiner::isCheaperToUseNegatedFPOps(SDValue X, SDValue Y) { - if (char LHSNeg = - TLI.isNegatibleForFree(X, DAG, LegalOperations, ForCodeSize)) - if (char RHSNeg = - TLI.isNegatibleForFree(Y, DAG, LegalOperations, ForCodeSize)) - // Both negated operands are at least as cheap as their counterparts. 
- // Check to see if at least one is cheaper negated. - if (LHSNeg == 2 || RHSNeg == 2) - return true; - - return false; -} - SDValue DAGCombiner::visitFMUL(SDNode *N) { SDValue N0 = N->getOperand(0); SDValue N1 = N->getOperand(1); @@ -12403,6 +12874,9 @@ SDValue DAGCombiner::visitFMUL(SDNode *N) { const TargetOptions &Options = DAG.getTarget().Options; const SDNodeFlags Flags = N->getFlags(); + if (SDValue R = DAG.simplifyFPBinop(N->getOpcode(), N0, N1, Flags)) + return R; + // fold vector ops if (VT.isVector()) { // This just handles C1 * C2 for vectors. Other vector folds are below. @@ -12464,13 +12938,18 @@ SDValue DAGCombiner::visitFMUL(SDNode *N) { return DAG.getNode(ISD::FNEG, DL, VT, N0); // -N0 * -N1 --> N0 * N1 - if (isCheaperToUseNegatedFPOps(N0, N1)) { - SDValue NegN0 = - TLI.getNegatedExpression(N0, DAG, LegalOperations, ForCodeSize); - SDValue NegN1 = - TLI.getNegatedExpression(N1, DAG, LegalOperations, ForCodeSize); + TargetLowering::NegatibleCost CostN0 = + TargetLowering::NegatibleCost::Expensive; + TargetLowering::NegatibleCost CostN1 = + TargetLowering::NegatibleCost::Expensive; + SDValue NegN0 = + TLI.getNegatedExpression(N0, DAG, LegalOperations, ForCodeSize, CostN0); + SDValue NegN1 = + TLI.getNegatedExpression(N1, DAG, LegalOperations, ForCodeSize, CostN1); + if (NegN0 && NegN1 && + (CostN0 == TargetLowering::NegatibleCost::Cheaper || + CostN1 == TargetLowering::NegatibleCost::Cheaper)) return DAG.getNode(ISD::FMUL, DL, VT, NegN0, NegN1, Flags); - } // fold (fmul X, (select (fcmp X > 0.0), -1.0, 1.0)) -> (fneg (fabs X)) // fold (fmul X, (select (fcmp X > 0.0), 1.0, -1.0)) -> (fabs X) @@ -12549,13 +13028,18 @@ SDValue DAGCombiner::visitFMA(SDNode *N) { } // (-N0 * -N1) + N2 --> (N0 * N1) + N2 - if (isCheaperToUseNegatedFPOps(N0, N1)) { - SDValue NegN0 = - TLI.getNegatedExpression(N0, DAG, LegalOperations, ForCodeSize); - SDValue NegN1 = - TLI.getNegatedExpression(N1, DAG, LegalOperations, ForCodeSize); + TargetLowering::NegatibleCost CostN0 = + TargetLowering::NegatibleCost::Expensive; + TargetLowering::NegatibleCost CostN1 = + TargetLowering::NegatibleCost::Expensive; + SDValue NegN0 = + TLI.getNegatedExpression(N0, DAG, LegalOperations, ForCodeSize, CostN0); + SDValue NegN1 = + TLI.getNegatedExpression(N1, DAG, LegalOperations, ForCodeSize, CostN1); + if (NegN0 && NegN1 && + (CostN0 == TargetLowering::NegatibleCost::Cheaper || + CostN1 == TargetLowering::NegatibleCost::Cheaper)) return DAG.getNode(ISD::FMA, DL, VT, NegN0, NegN1, N2, Flags); - } if (UnsafeFPMath) { if (N0CFP && N0CFP->isZero()) @@ -12641,13 +13125,10 @@ SDValue DAGCombiner::visitFMA(SDNode *N) { // fold ((fma (fneg X), Y, (fneg Z)) -> fneg (fma X, Y, Z)) // fold ((fma X, (fneg Y), (fneg Z)) -> fneg (fma X, Y, Z)) - if (!TLI.isFNegFree(VT) && - TLI.isNegatibleForFree(SDValue(N, 0), DAG, LegalOperations, - ForCodeSize) == 2) - return DAG.getNode(ISD::FNEG, DL, VT, - TLI.getNegatedExpression(SDValue(N, 0), DAG, - LegalOperations, ForCodeSize), - Flags); + if (!TLI.isFNegFree(VT)) + if (SDValue Neg = TLI.getCheaperNegatedExpression( + SDValue(N, 0), DAG, LegalOperations, ForCodeSize)) + return DAG.getNode(ISD::FNEG, DL, VT, Neg, Flags); return SDValue(); } @@ -12664,7 +13145,7 @@ SDValue DAGCombiner::combineRepeatedFPDivisors(SDNode *N) { // that only minsize should restrict this. 
bool UnsafeMath = DAG.getTarget().Options.UnsafeFPMath; const SDNodeFlags Flags = N->getFlags(); - if (!UnsafeMath && !Flags.hasAllowReciprocal()) + if (LegalDAG || (!UnsafeMath && !Flags.hasAllowReciprocal())) return SDValue(); // Skip if current node is a reciprocal/fneg-reciprocal. @@ -12735,6 +13216,9 @@ SDValue DAGCombiner::visitFDIV(SDNode *N) { const TargetOptions &Options = DAG.getTarget().Options; SDNodeFlags Flags = N->getFlags(); + if (SDValue R = DAG.simplifyFPBinop(N->getOpcode(), N0, N1, Flags)) + return R; + // fold vector ops if (VT.isVector()) if (SDValue FoldedVOp = SimplifyVBinOp(N)) @@ -12794,37 +13278,62 @@ SDValue DAGCombiner::visitFDIV(SDNode *N) { } else if (N1.getOpcode() == ISD::FMUL) { // Look through an FMUL. Even though this won't remove the FDIV directly, // it's still worthwhile to get rid of the FSQRT if possible. - SDValue SqrtOp; - SDValue OtherOp; + SDValue Sqrt, Y; if (N1.getOperand(0).getOpcode() == ISD::FSQRT) { - SqrtOp = N1.getOperand(0); - OtherOp = N1.getOperand(1); + Sqrt = N1.getOperand(0); + Y = N1.getOperand(1); } else if (N1.getOperand(1).getOpcode() == ISD::FSQRT) { - SqrtOp = N1.getOperand(1); - OtherOp = N1.getOperand(0); + Sqrt = N1.getOperand(1); + Y = N1.getOperand(0); } - if (SqrtOp.getNode()) { + if (Sqrt.getNode()) { + // If the other multiply operand is known positive, pull it into the + // sqrt. That will eliminate the division if we convert to an estimate: + // X / (fabs(A) * sqrt(Z)) --> X / sqrt(A*A*Z) --> X * rsqrt(A*A*Z) + // TODO: Also fold the case where A == Z (fabs is missing). + if (Flags.hasAllowReassociation() && N1.hasOneUse() && + N1->getFlags().hasAllowReassociation() && Sqrt.hasOneUse() && + Y.getOpcode() == ISD::FABS && Y.hasOneUse()) { + SDValue AA = DAG.getNode(ISD::FMUL, DL, VT, Y.getOperand(0), + Y.getOperand(0), Flags); + SDValue AAZ = + DAG.getNode(ISD::FMUL, DL, VT, AA, Sqrt.getOperand(0), Flags); + if (SDValue Rsqrt = buildRsqrtEstimate(AAZ, Flags)) + return DAG.getNode(ISD::FMUL, DL, VT, N0, Rsqrt, Flags); + + // Estimate creation failed. Clean up speculatively created nodes. + recursivelyDeleteUnusedNodes(AAZ.getNode()); + } + // We found a FSQRT, so try to make this fold: - // x / (y * sqrt(z)) -> x * (rsqrt(z) / y) - if (SDValue RV = buildRsqrtEstimate(SqrtOp.getOperand(0), Flags)) { - RV = DAG.getNode(ISD::FDIV, SDLoc(N1), VT, RV, OtherOp, Flags); - AddToWorklist(RV.getNode()); - return DAG.getNode(ISD::FMUL, DL, VT, N0, RV, Flags); + // X / (Y * sqrt(Z)) -> X * (rsqrt(Z) / Y) + if (SDValue Rsqrt = buildRsqrtEstimate(Sqrt.getOperand(0), Flags)) { + SDValue Div = DAG.getNode(ISD::FDIV, SDLoc(N1), VT, Rsqrt, Y, Flags); + AddToWorklist(Div.getNode()); + return DAG.getNode(ISD::FMUL, DL, VT, N0, Div, Flags); } } } // Fold into a reciprocal estimate and multiply instead of a real divide. 
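// Illustrative sketch (not part of the patch): pulling the known-positive
// factor under the square root uses fabs(a) * sqrt(z) == sqrt(a*a*z), which
// turns the division into a single multiply by an rsqrt; exact library math is
// used here in place of the target's estimate sequence.
#include <cassert>
#include <cmath>

int main() {
  double X = 20.0, A = -4.0, Z = 9.0;
  double Orig = X / (std::fabs(A) * std::sqrt(Z));   // X / (fabs(A) * sqrt(Z))
  double Folded = X * (1.0 / std::sqrt(A * A * Z));  // X * rsqrt(A*A*Z)
  assert(std::fabs(Orig - Folded) < 1e-12);
  return 0;
}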
- if (SDValue RV = BuildDivEstimate(N0, N1, Flags)) - return RV; + if (Options.NoInfsFPMath || Flags.hasNoInfs()) + if (SDValue RV = BuildDivEstimate(N0, N1, Flags)) + return RV; } // (fdiv (fneg X), (fneg Y)) -> (fdiv X, Y) - if (isCheaperToUseNegatedFPOps(N0, N1)) - return DAG.getNode( - ISD::FDIV, SDLoc(N), VT, - TLI.getNegatedExpression(N0, DAG, LegalOperations, ForCodeSize), - TLI.getNegatedExpression(N1, DAG, LegalOperations, ForCodeSize), Flags); + TargetLowering::NegatibleCost CostN0 = + TargetLowering::NegatibleCost::Expensive; + TargetLowering::NegatibleCost CostN1 = + TargetLowering::NegatibleCost::Expensive; + SDValue NegN0 = + TLI.getNegatedExpression(N0, DAG, LegalOperations, ForCodeSize, CostN0); + SDValue NegN1 = + TLI.getNegatedExpression(N1, DAG, LegalOperations, ForCodeSize, CostN1); + if (NegN0 && NegN1 && + (CostN0 == TargetLowering::NegatibleCost::Cheaper || + CostN1 == TargetLowering::NegatibleCost::Cheaper)) + return DAG.getNode(ISD::FDIV, SDLoc(N), VT, NegN0, NegN1, Flags); return SDValue(); } @@ -12835,6 +13344,10 @@ SDValue DAGCombiner::visitFREM(SDNode *N) { ConstantFPSDNode *N0CFP = dyn_cast<ConstantFPSDNode>(N0); ConstantFPSDNode *N1CFP = dyn_cast<ConstantFPSDNode>(N1); EVT VT = N->getValueType(0); + SDNodeFlags Flags = N->getFlags(); + + if (SDValue R = DAG.simplifyFPBinop(N->getOpcode(), N0, N1, Flags)) + return R; // fold (frem c1, c2) -> fmod(c1,c2) if (N0CFP && N1CFP) @@ -12848,8 +13361,12 @@ SDValue DAGCombiner::visitFREM(SDNode *N) { SDValue DAGCombiner::visitFSQRT(SDNode *N) { SDNodeFlags Flags = N->getFlags(); - if (!DAG.getTarget().Options.UnsafeFPMath && - !Flags.hasApproximateFuncs()) + const TargetOptions &Options = DAG.getTarget().Options; + + // Require 'ninf' flag since sqrt(+Inf) = +Inf, but the estimation goes as: + // sqrt(+Inf) == rsqrt(+Inf) * +Inf = 0 * +Inf = NaN + if ((!Options.UnsafeFPMath && !Flags.hasApproximateFuncs()) || + (!Options.NoInfsFPMath && !Flags.hasNoInfs())) return SDValue(); SDValue N0 = N->getOperand(0); @@ -13061,33 +13578,24 @@ SDValue DAGCombiner::visitSINT_TO_FP(SDNode *N) { } // The next optimizations are desirable only if SELECT_CC can be lowered. 
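// Illustrative sketch (not part of the patch): the new 'ninf' requirement on
// the FSQRT estimate exists because the expansion computes sqrt(x) as
// x * rsqrt(x); at x = +Inf that is Inf * 0 = NaN although sqrt(+Inf) = +Inf.
#include <cassert>
#include <cmath>
#include <limits>

int main() {
  double Inf = std::numeric_limits<double>::infinity();
  assert(std::sqrt(Inf) == Inf);            // the exact operation is well-defined
  double Rsqrt = 1.0 / std::sqrt(Inf);      // rsqrt(+Inf) == 0
  assert(std::isnan(Inf * Rsqrt));          // x * rsqrt(x) -> Inf * 0 -> NaN
  return 0;
}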
- if (TLI.isOperationLegalOrCustom(ISD::SELECT_CC, VT) || !LegalOperations) { - // fold (sint_to_fp (setcc x, y, cc)) -> (select_cc x, y, -1.0, 0.0,, cc) - if (N0.getOpcode() == ISD::SETCC && N0.getValueType() == MVT::i1 && - !VT.isVector() && - (!LegalOperations || - TLI.isOperationLegalOrCustom(ISD::ConstantFP, VT))) { - SDLoc DL(N); - SDValue Ops[] = - { N0.getOperand(0), N0.getOperand(1), - DAG.getConstantFP(-1.0, DL, VT), DAG.getConstantFP(0.0, DL, VT), - N0.getOperand(2) }; - return DAG.getNode(ISD::SELECT_CC, DL, VT, Ops); - } + // fold (sint_to_fp (setcc x, y, cc)) -> (select (setcc x, y, cc), -1.0, 0.0) + if (N0.getOpcode() == ISD::SETCC && N0.getValueType() == MVT::i1 && + !VT.isVector() && + (!LegalOperations || TLI.isOperationLegalOrCustom(ISD::ConstantFP, VT))) { + SDLoc DL(N); + return DAG.getSelect(DL, VT, N0, DAG.getConstantFP(-1.0, DL, VT), + DAG.getConstantFP(0.0, DL, VT)); + } - // fold (sint_to_fp (zext (setcc x, y, cc))) -> - // (select_cc x, y, 1.0, 0.0,, cc) - if (N0.getOpcode() == ISD::ZERO_EXTEND && - N0.getOperand(0).getOpcode() == ISD::SETCC &&!VT.isVector() && - (!LegalOperations || - TLI.isOperationLegalOrCustom(ISD::ConstantFP, VT))) { - SDLoc DL(N); - SDValue Ops[] = - { N0.getOperand(0).getOperand(0), N0.getOperand(0).getOperand(1), - DAG.getConstantFP(1.0, DL, VT), DAG.getConstantFP(0.0, DL, VT), - N0.getOperand(0).getOperand(2) }; - return DAG.getNode(ISD::SELECT_CC, DL, VT, Ops); - } + // fold (sint_to_fp (zext (setcc x, y, cc))) -> + // (select (setcc x, y, cc), 1.0, 0.0) + if (N0.getOpcode() == ISD::ZERO_EXTEND && + N0.getOperand(0).getOpcode() == ISD::SETCC && !VT.isVector() && + (!LegalOperations || TLI.isOperationLegalOrCustom(ISD::ConstantFP, VT))) { + SDLoc DL(N); + return DAG.getSelect(DL, VT, N0.getOperand(0), + DAG.getConstantFP(1.0, DL, VT), + DAG.getConstantFP(0.0, DL, VT)); } if (SDValue FTrunc = foldFPToIntToFP(N, DAG, TLI)) @@ -13121,19 +13629,12 @@ SDValue DAGCombiner::visitUINT_TO_FP(SDNode *N) { return DAG.getNode(ISD::SINT_TO_FP, SDLoc(N), VT, N0); } - // The next optimizations are desirable only if SELECT_CC can be lowered. 
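// Illustrative sketch (not part of the patch): an i1 setcc is 0 or, read as a
// signed 1-bit value, -1, so sint_to_fp of it can only produce 0.0 or -1.0;
// that is the plain select of FP constants the rewritten fold now emits.
#include <cassert>

int main() {
  int X = 3, Y = 7;
  int I1Signed = (X < Y) ? -1 : 0;                // i1 true sign-extends to -1
  double ViaSintToFp = (double)I1Signed;          // sint_to_fp (setcc X, Y, lt)
  double ViaSelect = (X < Y) ? -1.0 : 0.0;        // select (setcc), -1.0, 0.0
  assert(ViaSintToFp == ViaSelect);
  return 0;
}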
- if (TLI.isOperationLegalOrCustom(ISD::SELECT_CC, VT) || !LegalOperations) { - // fold (uint_to_fp (setcc x, y, cc)) -> (select_cc x, y, -1.0, 0.0,, cc) - if (N0.getOpcode() == ISD::SETCC && !VT.isVector() && - (!LegalOperations || - TLI.isOperationLegalOrCustom(ISD::ConstantFP, VT))) { - SDLoc DL(N); - SDValue Ops[] = - { N0.getOperand(0), N0.getOperand(1), - DAG.getConstantFP(1.0, DL, VT), DAG.getConstantFP(0.0, DL, VT), - N0.getOperand(2) }; - return DAG.getNode(ISD::SELECT_CC, DL, VT, Ops); - } + // fold (uint_to_fp (setcc x, y, cc)) -> (select (setcc x, y, cc), 1.0, 0.0) + if (N0.getOpcode() == ISD::SETCC && !VT.isVector() && + (!LegalOperations || TLI.isOperationLegalOrCustom(ISD::ConstantFP, VT))) { + SDLoc DL(N); + return DAG.getSelect(DL, VT, N0, DAG.getConstantFP(1.0, DL, VT), + DAG.getConstantFP(0.0, DL, VT)); } if (SDValue FTrunc = foldFPToIntToFP(N, DAG, TLI)) @@ -13378,12 +13879,14 @@ SDValue DAGCombiner::visitFNEG(SDNode *N) { if (isConstantFPBuildVectorOrConstantFP(N0)) return DAG.getNode(ISD::FNEG, SDLoc(N), VT, N0); - if (TLI.isNegatibleForFree(N0, DAG, LegalOperations, ForCodeSize)) - return TLI.getNegatedExpression(N0, DAG, LegalOperations, ForCodeSize); + if (SDValue NegN0 = + TLI.getNegatedExpression(N0, DAG, LegalOperations, ForCodeSize)) + return NegN0; - // -(X-Y) -> (Y-X) is unsafe because when X==Y, -0.0 != +0.0 FIXME: This is - // duplicated in isNegatibleForFree, but isNegatibleForFree doesn't know it - // was called from a context with a nsz flag if the input fsub does not. + // -(X-Y) -> (Y-X) is unsafe because when X==Y, -0.0 != +0.0 + // FIXME: This is duplicated in getNegatibleCost, but getNegatibleCost doesn't + // know it was called from a context with a nsz flag if the input fsub does + // not. if (N0.getOpcode() == ISD::FSUB && (DAG.getTarget().Options.NoSignedZerosFPMath || N->getFlags().hasNoSignedZeros()) && N0.hasOneUse()) { @@ -13539,8 +14042,12 @@ SDValue DAGCombiner::visitBRCOND(SDNode *N) { } if (N1.hasOneUse()) { + // rebuildSetCC calls visitXor which may change the Chain when there is a + // STRICT_FSETCC/STRICT_FSETCCS involved. Use a handle to track changes. + HandleSDNode ChainHandle(Chain); if (SDValue NewN1 = rebuildSetCC(N1)) - return DAG.getNode(ISD::BRCOND, SDLoc(N), MVT::Other, Chain, NewN1, N2); + return DAG.getNode(ISD::BRCOND, SDLoc(N), MVT::Other, + ChainHandle.getValue(), NewN1, N2); } return SDValue(); @@ -13592,8 +14099,8 @@ SDValue DAGCombiner::rebuildSetCC(SDValue N) { } } - // Transform br(xor(x, y)) -> br(x != y) - // Transform br(xor(xor(x,y), 1)) -> br (x == y) + // Transform (brcond (xor x, y)) -> (brcond (setcc, x, y, ne)) + // Transform (brcond (xor (xor x, y), -1)) -> (brcond (setcc, x, y, eq)) if (N.getOpcode() == ISD::XOR) { // Because we may call this on a speculatively constructed // SimplifiedSetCC Node, we need to simplify this node first. 
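// Illustrative sketch (not part of the patch): the branch-on-xor rewrites use
// x ^ y != 0 exactly when x != y, and xor-ing that i1 with -1 (a single-bit
// flip, same as xor with 1) turns the branch into an equality test.
#include <cassert>
#include <cstdint>

int main() {
  for (uint32_t X : {0u, 1u, 42u})
    for (uint32_t Y : {0u, 1u, 42u}) {
      assert(((X ^ Y) != 0) == (X != Y));  // brcond (xor x, y) -> brcond (setcc x, y, ne)
      bool Ne = (X != Y);
      assert((Ne ^ 1) == (X == Y));        // brcond (xor (xor x, y), -1) -> setcc x, y, eq
    }
  return 0;
}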
@@ -13617,16 +14124,17 @@ SDValue DAGCombiner::rebuildSetCC(SDValue N) { if (N.getOpcode() != ISD::XOR) return N; - SDNode *TheXor = N.getNode(); - - SDValue Op0 = TheXor->getOperand(0); - SDValue Op1 = TheXor->getOperand(1); + SDValue Op0 = N->getOperand(0); + SDValue Op1 = N->getOperand(1); if (Op0.getOpcode() != ISD::SETCC && Op1.getOpcode() != ISD::SETCC) { bool Equal = false; - if (isOneConstant(Op0) && Op0.hasOneUse() && - Op0.getOpcode() == ISD::XOR) { - TheXor = Op0.getNode(); + // (brcond (xor (xor x, y), -1)) -> (brcond (setcc x, y, eq)) + if (isBitwiseNot(N) && Op0.hasOneUse() && Op0.getOpcode() == ISD::XOR && + Op0.getValueType() == MVT::i1) { + N = Op0; + Op0 = N->getOperand(0); + Op1 = N->getOperand(1); Equal = true; } @@ -13634,7 +14142,7 @@ SDValue DAGCombiner::rebuildSetCC(SDValue N) { if (LegalTypes) SetCCVT = getSetCCResultType(SetCCVT); // Replace the uses of XOR with SETCC - return DAG.getSetCC(SDLoc(TheXor), SetCCVT, Op0, Op1, + return DAG.getSetCC(SDLoc(N), SetCCVT, Op0, Op1, Equal ? ISD::SETEQ : ISD::SETNE); } } @@ -13994,118 +14502,142 @@ bool DAGCombiner::CombineToPreIndexedLoadStore(SDNode *N) { return true; } -/// Try to combine a load/store with a add/sub of the base pointer node into a -/// post-indexed load/store. The transformation folded the add/subtract into the -/// new indexed load/store effectively and all of its uses are redirected to the -/// new load/store. -bool DAGCombiner::CombineToPostIndexedLoadStore(SDNode *N) { - if (Level < AfterLegalizeDAG) +static bool shouldCombineToPostInc(SDNode *N, SDValue Ptr, SDNode *PtrUse, + SDValue &BasePtr, SDValue &Offset, + ISD::MemIndexedMode &AM, + SelectionDAG &DAG, + const TargetLowering &TLI) { + if (PtrUse == N || + (PtrUse->getOpcode() != ISD::ADD && PtrUse->getOpcode() != ISD::SUB)) return false; - bool IsLoad = true; - bool IsMasked = false; - SDValue Ptr; - if (!getCombineLoadStoreParts(N, ISD::POST_INC, ISD::POST_DEC, IsLoad, IsMasked, - Ptr, TLI)) + if (!TLI.getPostIndexedAddressParts(N, PtrUse, BasePtr, Offset, AM, DAG)) return false; - if (Ptr.getNode()->hasOneUse()) + // Don't create a indexed load / store with zero offset. + if (isNullConstant(Offset)) return false; - for (SDNode *Op : Ptr.getNode()->uses()) { - if (Op == N || - (Op->getOpcode() != ISD::ADD && Op->getOpcode() != ISD::SUB)) - continue; + if (isa<FrameIndexSDNode>(BasePtr) || isa<RegisterSDNode>(BasePtr)) + return false; - SDValue BasePtr; - SDValue Offset; - ISD::MemIndexedMode AM = ISD::UNINDEXED; - if (TLI.getPostIndexedAddressParts(N, Op, BasePtr, Offset, AM, DAG)) { - // Don't create a indexed load / store with zero offset. - if (isNullConstant(Offset)) - continue; + SmallPtrSet<const SDNode *, 32> Visited; + for (SDNode *Use : BasePtr.getNode()->uses()) { + if (Use == Ptr.getNode()) + continue; - // Try turning it into a post-indexed load / store except when - // 1) All uses are load / store ops that use it as base ptr (and - // it may be folded as addressing mmode). - // 2) Op must be independent of N, i.e. Op is neither a predecessor - // nor a successor of N. Otherwise, if Op is folded that would - // create a cycle. + // No if there's a later user which could perform the index instead. 
+ if (isa<MemSDNode>(Use)) { + bool IsLoad = true; + bool IsMasked = false; + SDValue OtherPtr; + if (getCombineLoadStoreParts(Use, ISD::POST_INC, ISD::POST_DEC, IsLoad, + IsMasked, OtherPtr, TLI)) { + SmallVector<const SDNode *, 2> Worklist; + Worklist.push_back(Use); + if (SDNode::hasPredecessorHelper(N, Visited, Worklist)) + return false; + } + } - if (isa<FrameIndexSDNode>(BasePtr) || isa<RegisterSDNode>(BasePtr)) - continue; + // If all the uses are load / store addresses, then don't do the + // transformation. + if (Use->getOpcode() == ISD::ADD || Use->getOpcode() == ISD::SUB) { + for (SDNode *UseUse : Use->uses()) + if (canFoldInAddressingMode(Use, UseUse, DAG, TLI)) + return false; + } + } + return true; +} - // Check for #1. - bool TryNext = false; - for (SDNode *Use : BasePtr.getNode()->uses()) { - if (Use == Ptr.getNode()) - continue; +static SDNode *getPostIndexedLoadStoreOp(SDNode *N, bool &IsLoad, + bool &IsMasked, SDValue &Ptr, + SDValue &BasePtr, SDValue &Offset, + ISD::MemIndexedMode &AM, + SelectionDAG &DAG, + const TargetLowering &TLI) { + if (!getCombineLoadStoreParts(N, ISD::POST_INC, ISD::POST_DEC, IsLoad, + IsMasked, Ptr, TLI) || + Ptr.getNode()->hasOneUse()) + return nullptr; + + // Try turning it into a post-indexed load / store except when + // 1) All uses are load / store ops that use it as base ptr (and + // it may be folded as addressing mmode). + // 2) Op must be independent of N, i.e. Op is neither a predecessor + // nor a successor of N. Otherwise, if Op is folded that would + // create a cycle. + for (SDNode *Op : Ptr->uses()) { + // Check for #1. + if (!shouldCombineToPostInc(N, Ptr, Op, BasePtr, Offset, AM, DAG, TLI)) + continue; - // If all the uses are load / store addresses, then don't do the - // transformation. - if (Use->getOpcode() == ISD::ADD || Use->getOpcode() == ISD::SUB) { - bool RealUse = false; - for (SDNode *UseUse : Use->uses()) { - if (!canFoldInAddressingMode(Use, UseUse, DAG, TLI)) - RealUse = true; - } + // Check for #2. + SmallPtrSet<const SDNode *, 32> Visited; + SmallVector<const SDNode *, 8> Worklist; + // Ptr is predecessor to both N and Op. + Visited.insert(Ptr.getNode()); + Worklist.push_back(N); + Worklist.push_back(Op); + if (!SDNode::hasPredecessorHelper(N, Visited, Worklist) && + !SDNode::hasPredecessorHelper(Op, Visited, Worklist)) + return Op; + } + return nullptr; +} - if (!RealUse) { - TryNext = true; - break; - } - } - } +/// Try to combine a load/store with a add/sub of the base pointer node into a +/// post-indexed load/store. The transformation folded the add/subtract into the +/// new indexed load/store effectively and all of its uses are redirected to the +/// new load/store. +bool DAGCombiner::CombineToPostIndexedLoadStore(SDNode *N) { + if (Level < AfterLegalizeDAG) + return false; - if (TryNext) - continue; + bool IsLoad = true; + bool IsMasked = false; + SDValue Ptr; + SDValue BasePtr; + SDValue Offset; + ISD::MemIndexedMode AM = ISD::UNINDEXED; + SDNode *Op = getPostIndexedLoadStoreOp(N, IsLoad, IsMasked, Ptr, BasePtr, + Offset, AM, DAG, TLI); + if (!Op) + return false; - // Check for #2. - SmallPtrSet<const SDNode *, 32> Visited; - SmallVector<const SDNode *, 8> Worklist; - // Ptr is predecessor to both N and Op. - Visited.insert(Ptr.getNode()); - Worklist.push_back(N); - Worklist.push_back(Op); - if (!SDNode::hasPredecessorHelper(N, Visited, Worklist) && - !SDNode::hasPredecessorHelper(Op, Visited, Worklist)) { - SDValue Result; - if (!IsMasked) - Result = IsLoad ? 
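// Illustrative sketch (not part of the patch): the refactored helpers still
// target the same idiom -- a load or store whose base pointer is then advanced
// by an add/sub -- and fold the pair into one post-indexed memory operation
// when the target supports it and the add is neither better folded as an
// address nor dependent on the memory op.
#include <cassert>
#include <cstdint>

static uint64_t sum(const uint32_t *P, unsigned N) {
  uint64_t S = 0;
  while (N--) {
    uint32_t V = *P;    // load P
    P = P + 1;          // add P, 4 -- the candidate for ISD::POST_INC folding
    S += V;
  }
  return S;
}

int main() {
  uint32_t A[4] = {1, 2, 3, 4};
  assert(sum(A, 4) == 10);
  return 0;
}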
DAG.getIndexedLoad(SDValue(N, 0), SDLoc(N), BasePtr, - Offset, AM) - : DAG.getIndexedStore(SDValue(N, 0), SDLoc(N), + SDValue Result; + if (!IsMasked) + Result = IsLoad ? DAG.getIndexedLoad(SDValue(N, 0), SDLoc(N), BasePtr, + Offset, AM) + : DAG.getIndexedStore(SDValue(N, 0), SDLoc(N), + BasePtr, Offset, AM); + else + Result = IsLoad ? DAG.getIndexedMaskedLoad(SDValue(N, 0), SDLoc(N), + BasePtr, Offset, AM) + : DAG.getIndexedMaskedStore(SDValue(N, 0), SDLoc(N), BasePtr, Offset, AM); - else - Result = IsLoad ? DAG.getIndexedMaskedLoad(SDValue(N, 0), SDLoc(N), - BasePtr, Offset, AM) - : DAG.getIndexedMaskedStore(SDValue(N, 0), SDLoc(N), - BasePtr, Offset, AM); - ++PostIndexedNodes; - ++NodesCombined; - LLVM_DEBUG(dbgs() << "\nReplacing.5 "; N->dump(&DAG); - dbgs() << "\nWith: "; Result.getNode()->dump(&DAG); - dbgs() << '\n'); - WorklistRemover DeadNodes(*this); - if (IsLoad) { - DAG.ReplaceAllUsesOfValueWith(SDValue(N, 0), Result.getValue(0)); - DAG.ReplaceAllUsesOfValueWith(SDValue(N, 1), Result.getValue(2)); - } else { - DAG.ReplaceAllUsesOfValueWith(SDValue(N, 0), Result.getValue(1)); - } - - // Finally, since the node is now dead, remove it from the graph. - deleteAndRecombine(N); - - // Replace the uses of Use with uses of the updated base value. - DAG.ReplaceAllUsesOfValueWith(SDValue(Op, 0), - Result.getValue(IsLoad ? 1 : 0)); - deleteAndRecombine(Op); - return true; - } - } + ++PostIndexedNodes; + ++NodesCombined; + LLVM_DEBUG(dbgs() << "\nReplacing.5 "; N->dump(&DAG); + dbgs() << "\nWith: "; Result.getNode()->dump(&DAG); + dbgs() << '\n'); + WorklistRemover DeadNodes(*this); + if (IsLoad) { + DAG.ReplaceAllUsesOfValueWith(SDValue(N, 0), Result.getValue(0)); + DAG.ReplaceAllUsesOfValueWith(SDValue(N, 1), Result.getValue(2)); + } else { + DAG.ReplaceAllUsesOfValueWith(SDValue(N, 0), Result.getValue(1)); } - return false; + // Finally, since the node is now dead, remove it from the graph. + deleteAndRecombine(N); + + // Replace the uses of Use with uses of the updated base value. + DAG.ReplaceAllUsesOfValueWith(SDValue(Op, 0), + Result.getValue(IsLoad ? 1 : 0)); + deleteAndRecombine(Op); + return true; } /// Return the base-pointer arithmetic from an indexed \p LD. @@ -14222,11 +14754,11 @@ SDValue DAGCombiner::ForwardStoreValueToDirectLoad(LoadSDNode *LD) { auto ReplaceLd = [&](LoadSDNode *LD, SDValue Val, SDValue Chain) -> SDValue { if (LD->isIndexed()) { - bool IsSub = (LD->getAddressingMode() == ISD::PRE_DEC || - LD->getAddressingMode() == ISD::POST_DEC); - unsigned Opc = IsSub ? ISD::SUB : ISD::ADD; - SDValue Idx = DAG.getNode(Opc, SDLoc(LD), LD->getOperand(1).getValueType(), - LD->getOperand(1), LD->getOperand(2)); + // Cannot handle opaque target constants and we must respect the user's + // request not to split indexes from loads. + if (!canSplitIdx(LD)) + return SDValue(); + SDValue Idx = SplitIndexingFromLoad(LD); SDValue Ops[] = {Val, Idx, Chain}; return CombineTo(LD, Ops, 3); } @@ -14322,14 +14854,12 @@ SDValue DAGCombiner::visitLOAD(SDNode *N) { // the indexing into an add/sub directly (that TargetConstant may not be // valid for a different type of node, and we cannot convert an opaque // target constant into a regular constant). 
- bool HasOTCInc = LD->getOperand(2).getOpcode() == ISD::TargetConstant && - cast<ConstantSDNode>(LD->getOperand(2))->isOpaque(); + bool CanSplitIdx = canSplitIdx(LD); - if (!N->hasAnyUseOfValue(0) && - ((MaySplitLoadIndex && !HasOTCInc) || !N->hasAnyUseOfValue(1))) { + if (!N->hasAnyUseOfValue(0) && (CanSplitIdx || !N->hasAnyUseOfValue(1))) { SDValue Undef = DAG.getUNDEF(N->getValueType(0)); SDValue Index; - if (N->hasAnyUseOfValue(1) && MaySplitLoadIndex && !HasOTCInc) { + if (N->hasAnyUseOfValue(1) && CanSplitIdx) { Index = SplitIndexingFromLoad(LD); // Try to fold the base pointer arithmetic into subsequent loads and // stores. @@ -14356,11 +14886,12 @@ SDValue DAGCombiner::visitLOAD(SDNode *N) { // Try to infer better alignment information than the load already has. if (OptLevel != CodeGenOpt::None && LD->isUnindexed() && !LD->isAtomic()) { - if (unsigned Align = DAG.InferPtrAlignment(Ptr)) { - if (Align > LD->getAlignment() && LD->getSrcValueOffset() % Align == 0) { + if (MaybeAlign Alignment = DAG.InferPtrAlign(Ptr)) { + if (*Alignment > LD->getAlign() && + isAligned(*Alignment, LD->getSrcValueOffset())) { SDValue NewLoad = DAG.getExtLoad( LD->getExtensionType(), SDLoc(N), LD->getValueType(0), Chain, Ptr, - LD->getPointerInfo(), LD->getMemoryVT(), Align, + LD->getPointerInfo(), LD->getMemoryVT(), *Alignment, LD->getMemOperand()->getFlags(), LD->getAAInfo()); // NewLoad will always be N as we are only refining the alignment assert(NewLoad.getNode() == N); @@ -14557,11 +15088,11 @@ struct LoadedSlice { } /// Get the alignment of the load used for this slice. - unsigned getAlignment() const { - unsigned Alignment = Origin->getAlignment(); + Align getAlign() const { + Align Alignment = Origin->getAlign(); uint64_t Offset = getOffsetFromBase(); if (Offset != 0) - Alignment = MinAlign(Alignment, Alignment + Offset); + Alignment = commonAlignment(Alignment, Alignment.value() + Offset); return Alignment; } @@ -14657,8 +15188,8 @@ struct LoadedSlice { // Create the load for the slice. SDValue LastInst = DAG->getLoad(SliceType, SDLoc(Origin), Origin->getChain(), BaseAddr, - Origin->getPointerInfo().getWithOffset(Offset), - getAlignment(), Origin->getMemOperand()->getFlags()); + Origin->getPointerInfo().getWithOffset(Offset), getAlign(), + Origin->getMemOperand()->getFlags()); // If the final type is not the same as the loaded type, this means that // we have to pad with zero. Create a zero extend for that. EVT FinalType = Inst->getValueType(0); @@ -14699,10 +15230,10 @@ struct LoadedSlice { // Check if it will be merged with the load. // 1. Check the alignment constraint. - unsigned RequiredAlignment = DAG->getDataLayout().getABITypeAlignment( + Align RequiredAlignment = DAG->getDataLayout().getABITypeAlign( ResVT.getTypeForEVT(*DAG->getContext())); - if (RequiredAlignment > getAlignment()) + if (RequiredAlignment > getAlign()) return false; // 2. Check that the load is a legal operation for that type. @@ -14788,14 +15319,14 @@ static void adjustCostForPairing(SmallVectorImpl<LoadedSlice> &LoadedSlices, continue; // Check if the target supplies paired loads for this type. - unsigned RequiredAlignment = 0; + Align RequiredAlignment; if (!TLI.hasPairedLoad(LoadedType, RequiredAlignment)) { // move to the next pair, this type is hopeless. Second = nullptr; continue; } // Check if we meet the alignment requirement. - if (RequiredAlignment > First->getAlignment()) + if (First->getAlign() < RequiredAlignment) continue; // Check that both loads are next to each other in memory. 
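// --- Illustrative sketch, not part of the diff: the quantity the
// getAlign()/commonAlignment change above computes. The alignment known for
// "pointer + Offset", when the pointer is Alignment-aligned, is the largest
// power of two dividing both values. offsetAlignment is a stand-alone
// stand-in, not the LLVM helper itself.
#include <cstdint>
static uint64_t offsetAlignment(uint64_t Alignment, uint64_t Offset) {
  uint64_t Both = Alignment | Offset;
  return Both & (~Both + 1);   // lowest set bit of (Alignment | Offset)
}
// offsetAlignment(16, 4) == 4; offsetAlignment(16, 0) == 16.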
@@ -14868,6 +15399,12 @@ bool DAGCombiner::SliceUpLoad(SDNode *N) { !LD->getValueType(0).isInteger()) return false; + // The algorithm to split up a load of a scalable vector into individual + // elements currently requires knowing the length of the loaded type, + // so will need adjusting to work on scalable vectors. + if (LD->getValueType(0).isScalableVector()) + return false; + // Keep track of already used bits to detect overlapping values. // In that case, we will just abort the transformation. APInt UsedBits(LD->getValueSizeInBits(0), 0); @@ -15112,7 +15649,7 @@ SDValue DAGCombiner::ReduceLoadOpStoreWidth(SDNode *N) { // Y is known to provide just those bytes. If so, we try to replace the // load + replace + store sequence with a single (narrower) store, which makes // the load dead. - if (Opc == ISD::OR) { + if (Opc == ISD::OR && EnableShrinkLoadReplaceStoreWithStore) { std::pair<unsigned, unsigned> MaskedLoad; MaskedLoad = CheckForMaskedLoad(Value.getOperand(0), Ptr, Chain); if (MaskedLoad.first) @@ -15128,6 +15665,9 @@ SDValue DAGCombiner::ReduceLoadOpStoreWidth(SDNode *N) { return NewST; } + if (!EnableReduceLoadOpStoreWidth) + return SDValue(); + if ((Opc != ISD::OR && Opc != ISD::XOR && Opc != ISD::AND) || Value.getOperand(1).getOpcode() != ISD::Constant) return SDValue(); @@ -15181,9 +15721,9 @@ SDValue DAGCombiner::ReduceLoadOpStoreWidth(SDNode *N) { if (DAG.getDataLayout().isBigEndian()) PtrOff = (BitWidth + 7 - NewBW) / 8 - PtrOff; - unsigned NewAlign = MinAlign(LD->getAlignment(), PtrOff); + Align NewAlign = commonAlignment(LD->getAlign(), PtrOff); Type *NewVTTy = NewVT.getTypeForEVT(*DAG.getContext()); - if (NewAlign < DAG.getDataLayout().getABITypeAlignment(NewVTTy)) + if (NewAlign < DAG.getDataLayout().getABITypeAlign(NewVTTy)) return SDValue(); SDValue NewPtr = DAG.getMemBasePlusOffset(Ptr, PtrOff, SDLoc(LD)); @@ -15229,17 +15769,24 @@ SDValue DAGCombiner::TransformFPLoadStorePair(SDNode *N) { ST->getPointerInfo().getAddrSpace() != 0) return SDValue(); - EVT IntVT = EVT::getIntegerVT(*DAG.getContext(), VT.getSizeInBits()); + TypeSize VTSize = VT.getSizeInBits(); + + // We don't know the size of scalable types at compile time so we cannot + // create an integer of the equivalent size. + if (VTSize.isScalable()) + return SDValue(); + + EVT IntVT = EVT::getIntegerVT(*DAG.getContext(), VTSize.getFixedSize()); if (!TLI.isOperationLegal(ISD::LOAD, IntVT) || !TLI.isOperationLegal(ISD::STORE, IntVT) || !TLI.isDesirableToTransformToIntegerOp(ISD::LOAD, VT) || !TLI.isDesirableToTransformToIntegerOp(ISD::STORE, VT)) return SDValue(); - unsigned LDAlign = LD->getAlignment(); - unsigned STAlign = ST->getAlignment(); + Align LDAlign = LD->getAlign(); + Align STAlign = ST->getAlign(); Type *IntVTTy = IntVT.getTypeForEVT(*DAG.getContext()); - unsigned ABIAlign = DAG.getDataLayout().getABITypeAlignment(IntVTTy); + Align ABIAlign = DAG.getDataLayout().getABITypeAlign(IntVTTy); if (LDAlign < ABIAlign || STAlign < ABIAlign) return SDValue(); @@ -15356,7 +15903,7 @@ SDValue DAGCombiner::getMergeStoreChains(SmallVectorImpl<MemOpLink> &StoreNodes, return DAG.getTokenFactor(StoreDL, Chains); } -bool DAGCombiner::MergeStoresOfConstantsOrVecElts( +bool DAGCombiner::mergeStoresOfConstantsOrVecElts( SmallVectorImpl<MemOpLink> &StoreNodes, EVT MemVT, unsigned NumStores, bool IsConstantSrc, bool UseVector, bool UseTrunc) { // Make sure we have something to merge. 
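// --- Illustrative sketch, not part of the diff: the "load/<replace
// bytes>/store with a narrower store" idea that
// EnableShrinkLoadReplaceStoreWithStore guards above. Replacing the low byte
// of a loaded 32-bit value and storing it back is the same as one 1-byte
// store (shown for a little-endian layout), which leaves the wide load dead.
// Plain C++, no LLVM API.
#include <cstdint>
#include <cstring>
static void replaceLowByteWide(uint32_t *P, uint8_t B) {
  uint32_t V = *P;                 // load
  V = (V & ~0xFFu) | B;            // <replace bytes>
  *P = V;                          // store
}
static void replaceLowByteNarrow(uint32_t *P, uint8_t B) {
  std::memcpy(P, &B, 1);           // single narrower store
}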
@@ -15530,14 +16077,12 @@ void DAGCombiner::getStoreMergeCandidates( if (BasePtr.getBase().isUndef()) return; - bool IsConstantSrc = isa<ConstantSDNode>(Val) || isa<ConstantFPSDNode>(Val); - bool IsExtractVecSrc = (Val.getOpcode() == ISD::EXTRACT_VECTOR_ELT || - Val.getOpcode() == ISD::EXTRACT_SUBVECTOR); - bool IsLoadSrc = isa<LoadSDNode>(Val); + StoreSource StoreSrc = getStoreSource(Val); + assert(StoreSrc != StoreSource::Unknown && "Expected known source for store"); BaseIndexOffset LBasePtr; // Match on loadbaseptr if relevant. EVT LoadVT; - if (IsLoadSrc) { + if (StoreSrc == StoreSource::Load) { auto *Ld = cast<LoadSDNode>(Val); LBasePtr = BaseIndexOffset::match(Ld, DAG); LoadVT = Ld->getMemoryVT(); @@ -15565,7 +16110,7 @@ void DAGCombiner::getStoreMergeCandidates( // Allow merging constants of different types as integers. bool NoTypeMatch = (MemVT.isInteger()) ? !MemVT.bitsEq(Other->getMemoryVT()) : Other->getMemoryVT() != MemVT; - if (IsLoadSrc) { + if (StoreSrc == StoreSource::Load) { if (NoTypeMatch) return false; // The Load's Base Ptr must also match @@ -15589,13 +16134,13 @@ void DAGCombiner::getStoreMergeCandidates( } else return false; } - if (IsConstantSrc) { + if (StoreSrc == StoreSource::Constant) { if (NoTypeMatch) return false; if (!(isa<ConstantSDNode>(OtherBC) || isa<ConstantFPSDNode>(OtherBC))) return false; } - if (IsExtractVecSrc) { + if (StoreSrc == StoreSource::Extract) { // Do not merge truncated stores here. if (Other->isTruncatingStore()) return false; @@ -15736,77 +16281,22 @@ bool DAGCombiner::checkMergeStoreCandidatesForDependencies( return true; } -bool DAGCombiner::MergeConsecutiveStores(StoreSDNode *St) { - if (OptLevel == CodeGenOpt::None || !EnableStoreMerging) - return false; - - EVT MemVT = St->getMemoryVT(); - int64_t ElementSizeBytes = MemVT.getStoreSize(); - unsigned NumMemElts = MemVT.isVector() ? MemVT.getVectorNumElements() : 1; - - if (MemVT.getSizeInBits() * 2 > MaximumLegalStoreInBits) - return false; - - bool NoVectors = DAG.getMachineFunction().getFunction().hasFnAttribute( - Attribute::NoImplicitFloat); - - // This function cannot currently deal with non-byte-sized memory sizes. - if (ElementSizeBytes * 8 != (int64_t)MemVT.getSizeInBits()) - return false; - - if (!MemVT.isSimple()) - return false; - - // Perform an early exit check. Do not bother looking at stored values that - // are not constants, loads, or extracted vector elements. - SDValue StoredVal = peekThroughBitcasts(St->getValue()); - bool IsLoadSrc = isa<LoadSDNode>(StoredVal); - bool IsConstantSrc = isa<ConstantSDNode>(StoredVal) || - isa<ConstantFPSDNode>(StoredVal); - bool IsExtractVecSrc = (StoredVal.getOpcode() == ISD::EXTRACT_VECTOR_ELT || - StoredVal.getOpcode() == ISD::EXTRACT_SUBVECTOR); - bool IsNonTemporalStore = St->isNonTemporal(); - bool IsNonTemporalLoad = - IsLoadSrc && cast<LoadSDNode>(StoredVal)->isNonTemporal(); - - if (!IsConstantSrc && !IsLoadSrc && !IsExtractVecSrc) - return false; - - SmallVector<MemOpLink, 8> StoreNodes; - SDNode *RootNode; - // Find potential store merge candidates by searching through chain sub-DAG - getStoreMergeCandidates(St, StoreNodes, RootNode); - - // Check if there is anything to merge. - if (StoreNodes.size() < 2) - return false; - - // Sort the memory operands according to their distance from the - // base pointer. - llvm::sort(StoreNodes, [](MemOpLink LHS, MemOpLink RHS) { - return LHS.OffsetFromBase < RHS.OffsetFromBase; - }); - - // Store Merge attempts to merge the lowest stores. 
This generally - // works out as if successful, as the remaining stores are checked - // after the first collection of stores is merged. However, in the - // case that a non-mergeable store is found first, e.g., {p[-2], - // p[0], p[1], p[2], p[3]}, we would fail and miss the subsequent - // mergeable cases. To prevent this, we prune such stores from the - // front of StoreNodes here. - - bool RV = false; - while (StoreNodes.size() > 1) { +unsigned +DAGCombiner::getConsecutiveStores(SmallVectorImpl<MemOpLink> &StoreNodes, + int64_t ElementSizeBytes) const { + while (true) { + // Find a store past the width of the first store. size_t StartIdx = 0; while ((StartIdx + 1 < StoreNodes.size()) && StoreNodes[StartIdx].OffsetFromBase + ElementSizeBytes != - StoreNodes[StartIdx + 1].OffsetFromBase) + StoreNodes[StartIdx + 1].OffsetFromBase) ++StartIdx; // Bail if we don't have enough candidates to merge. if (StartIdx + 1 >= StoreNodes.size()) - return RV; + return 0; + // Trim stores that overlapped with the first store. if (StartIdx) StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + StartIdx); @@ -15822,302 +16312,345 @@ bool DAGCombiner::MergeConsecutiveStores(StoreSDNode *St) { break; NumConsecutiveStores = i + 1; } + if (NumConsecutiveStores > 1) + return NumConsecutiveStores; - if (NumConsecutiveStores < 2) { - StoreNodes.erase(StoreNodes.begin(), - StoreNodes.begin() + NumConsecutiveStores); - continue; - } - - // The node with the lowest store address. - LLVMContext &Context = *DAG.getContext(); - const DataLayout &DL = DAG.getDataLayout(); - - // Store the constants into memory as one consecutive store. - if (IsConstantSrc) { - while (NumConsecutiveStores >= 2) { - LSBaseSDNode *FirstInChain = StoreNodes[0].MemNode; - unsigned FirstStoreAS = FirstInChain->getAddressSpace(); - unsigned FirstStoreAlign = FirstInChain->getAlignment(); - unsigned LastLegalType = 1; - unsigned LastLegalVectorType = 1; - bool LastIntegerTrunc = false; - bool NonZero = false; - unsigned FirstZeroAfterNonZero = NumConsecutiveStores; - for (unsigned i = 0; i < NumConsecutiveStores; ++i) { - StoreSDNode *ST = cast<StoreSDNode>(StoreNodes[i].MemNode); - SDValue StoredVal = ST->getValue(); - bool IsElementZero = false; - if (ConstantSDNode *C = dyn_cast<ConstantSDNode>(StoredVal)) - IsElementZero = C->isNullValue(); - else if (ConstantFPSDNode *C = dyn_cast<ConstantFPSDNode>(StoredVal)) - IsElementZero = C->getConstantFPValue()->isNullValue(); - if (IsElementZero) { - if (NonZero && FirstZeroAfterNonZero == NumConsecutiveStores) - FirstZeroAfterNonZero = i; - } - NonZero |= !IsElementZero; - - // Find a legal type for the constant store. - unsigned SizeInBits = (i + 1) * ElementSizeBytes * 8; - EVT StoreTy = EVT::getIntegerVT(Context, SizeInBits); - bool IsFast = false; + // There are no consecutive stores at the start of the list. + // Remove the first store and try again. + StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + 1); + } +} - // Break early when size is too large to be legal. - if (StoreTy.getSizeInBits() > MaximumLegalStoreInBits) - break; +bool DAGCombiner::tryStoreMergeOfConstants( + SmallVectorImpl<MemOpLink> &StoreNodes, unsigned NumConsecutiveStores, + EVT MemVT, SDNode *RootNode, bool AllowVectors) { + LLVMContext &Context = *DAG.getContext(); + const DataLayout &DL = DAG.getDataLayout(); + int64_t ElementSizeBytes = MemVT.getStoreSize(); + unsigned NumMemElts = MemVT.isVector() ? 
MemVT.getVectorNumElements() : 1; + bool MadeChange = false; + + // Store the constants into memory as one consecutive store. + while (NumConsecutiveStores >= 2) { + LSBaseSDNode *FirstInChain = StoreNodes[0].MemNode; + unsigned FirstStoreAS = FirstInChain->getAddressSpace(); + unsigned FirstStoreAlign = FirstInChain->getAlignment(); + unsigned LastLegalType = 1; + unsigned LastLegalVectorType = 1; + bool LastIntegerTrunc = false; + bool NonZero = false; + unsigned FirstZeroAfterNonZero = NumConsecutiveStores; + for (unsigned i = 0; i < NumConsecutiveStores; ++i) { + StoreSDNode *ST = cast<StoreSDNode>(StoreNodes[i].MemNode); + SDValue StoredVal = ST->getValue(); + bool IsElementZero = false; + if (ConstantSDNode *C = dyn_cast<ConstantSDNode>(StoredVal)) + IsElementZero = C->isNullValue(); + else if (ConstantFPSDNode *C = dyn_cast<ConstantFPSDNode>(StoredVal)) + IsElementZero = C->getConstantFPValue()->isNullValue(); + if (IsElementZero) { + if (NonZero && FirstZeroAfterNonZero == NumConsecutiveStores) + FirstZeroAfterNonZero = i; + } + NonZero |= !IsElementZero; - if (TLI.isTypeLegal(StoreTy) && - TLI.canMergeStoresTo(FirstStoreAS, StoreTy, DAG) && - TLI.allowsMemoryAccess(Context, DL, StoreTy, - *FirstInChain->getMemOperand(), &IsFast) && - IsFast) { - LastIntegerTrunc = false; - LastLegalType = i + 1; - // Or check whether a truncstore is legal. - } else if (TLI.getTypeAction(Context, StoreTy) == - TargetLowering::TypePromoteInteger) { - EVT LegalizedStoredValTy = - TLI.getTypeToTransformTo(Context, StoredVal.getValueType()); - if (TLI.isTruncStoreLegal(LegalizedStoredValTy, StoreTy) && - TLI.canMergeStoresTo(FirstStoreAS, LegalizedStoredValTy, DAG) && - TLI.allowsMemoryAccess(Context, DL, StoreTy, - *FirstInChain->getMemOperand(), - &IsFast) && - IsFast) { - LastIntegerTrunc = true; - LastLegalType = i + 1; - } - } + // Find a legal type for the constant store. + unsigned SizeInBits = (i + 1) * ElementSizeBytes * 8; + EVT StoreTy = EVT::getIntegerVT(Context, SizeInBits); + bool IsFast = false; - // We only use vectors if the constant is known to be zero or the - // target allows it and the function is not marked with the - // noimplicitfloat attribute. - if ((!NonZero || - TLI.storeOfVectorConstantIsCheap(MemVT, i + 1, FirstStoreAS)) && - !NoVectors) { - // Find a legal type for the vector store. - unsigned Elts = (i + 1) * NumMemElts; - EVT Ty = EVT::getVectorVT(Context, MemVT.getScalarType(), Elts); - if (TLI.isTypeLegal(Ty) && TLI.isTypeLegal(MemVT) && - TLI.canMergeStoresTo(FirstStoreAS, Ty, DAG) && - TLI.allowsMemoryAccess( - Context, DL, Ty, *FirstInChain->getMemOperand(), &IsFast) && - IsFast) - LastLegalVectorType = i + 1; - } - } + // Break early when size is too large to be legal. + if (StoreTy.getSizeInBits() > MaximumLegalStoreInBits) + break; - bool UseVector = (LastLegalVectorType > LastLegalType) && !NoVectors; - unsigned NumElem = (UseVector) ? LastLegalVectorType : LastLegalType; - - // Check if we found a legal integer type that creates a meaningful - // merge. - if (NumElem < 2) { - // We know that candidate stores are in order and of correct - // shape. While there is no mergeable sequence from the - // beginning one may start later in the sequence. The only - // reason a merge of size N could have failed where another of - // the same size would not have, is if the alignment has - // improved or we've dropped a non-zero value. Drop as many - // candidates as we can here. 
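// --- Illustrative sketch, not part of the diff: the shape of the legality
// scan in the loop above. Walk the consecutive stores and remember the
// longest prefix whose combined width the target can handle;
// isLegalStoreWidth is a hypothetical stand-in for the TLI legality and
// fast-memory-access checks.
#include <cstddef>
static std::size_t longestLegalPrefix(std::size_t NumStores,
                                      std::size_t ElementSizeBytes,
                                      std::size_t MaxLegalBits,
                                      bool (*isLegalStoreWidth)(std::size_t)) {
  std::size_t LastLegal = 1;
  for (std::size_t I = 0; I < NumStores; ++I) {
    std::size_t Bits = (I + 1) * ElementSizeBytes * 8;
    if (Bits > MaxLegalBits)
      break;                       // break early when too large to be legal
    if (isLegalStoreWidth(Bits))
      LastLegal = I + 1;
  }
  return LastLegal;
}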
- unsigned NumSkip = 1; - while ( - (NumSkip < NumConsecutiveStores) && - (NumSkip < FirstZeroAfterNonZero) && - (StoreNodes[NumSkip].MemNode->getAlignment() <= FirstStoreAlign)) - NumSkip++; - - StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + NumSkip); - NumConsecutiveStores -= NumSkip; - continue; + if (TLI.isTypeLegal(StoreTy) && + TLI.canMergeStoresTo(FirstStoreAS, StoreTy, DAG) && + TLI.allowsMemoryAccess(Context, DL, StoreTy, + *FirstInChain->getMemOperand(), &IsFast) && + IsFast) { + LastIntegerTrunc = false; + LastLegalType = i + 1; + // Or check whether a truncstore is legal. + } else if (TLI.getTypeAction(Context, StoreTy) == + TargetLowering::TypePromoteInteger) { + EVT LegalizedStoredValTy = + TLI.getTypeToTransformTo(Context, StoredVal.getValueType()); + if (TLI.isTruncStoreLegal(LegalizedStoredValTy, StoreTy) && + TLI.canMergeStoresTo(FirstStoreAS, LegalizedStoredValTy, DAG) && + TLI.allowsMemoryAccess(Context, DL, StoreTy, + *FirstInChain->getMemOperand(), &IsFast) && + IsFast) { + LastIntegerTrunc = true; + LastLegalType = i + 1; } + } - // Check that we can merge these candidates without causing a cycle. - if (!checkMergeStoreCandidatesForDependencies(StoreNodes, NumElem, - RootNode)) { - StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + NumElem); - NumConsecutiveStores -= NumElem; - continue; - } + // We only use vectors if the constant is known to be zero or the + // target allows it and the function is not marked with the + // noimplicitfloat attribute. + if ((!NonZero || + TLI.storeOfVectorConstantIsCheap(MemVT, i + 1, FirstStoreAS)) && + AllowVectors) { + // Find a legal type for the vector store. + unsigned Elts = (i + 1) * NumMemElts; + EVT Ty = EVT::getVectorVT(Context, MemVT.getScalarType(), Elts); + if (TLI.isTypeLegal(Ty) && TLI.isTypeLegal(MemVT) && + TLI.canMergeStoresTo(FirstStoreAS, Ty, DAG) && + TLI.allowsMemoryAccess(Context, DL, Ty, + *FirstInChain->getMemOperand(), &IsFast) && + IsFast) + LastLegalVectorType = i + 1; + } + } - RV |= MergeStoresOfConstantsOrVecElts(StoreNodes, MemVT, NumElem, true, - UseVector, LastIntegerTrunc); + bool UseVector = (LastLegalVectorType > LastLegalType) && AllowVectors; + unsigned NumElem = (UseVector) ? LastLegalVectorType : LastLegalType; + + // Check if we found a legal integer type that creates a meaningful + // merge. + if (NumElem < 2) { + // We know that candidate stores are in order and of correct + // shape. While there is no mergeable sequence from the + // beginning one may start later in the sequence. The only + // reason a merge of size N could have failed where another of + // the same size would not have, is if the alignment has + // improved or we've dropped a non-zero value. Drop as many + // candidates as we can here. + unsigned NumSkip = 1; + while ((NumSkip < NumConsecutiveStores) && + (NumSkip < FirstZeroAfterNonZero) && + (StoreNodes[NumSkip].MemNode->getAlignment() <= FirstStoreAlign)) + NumSkip++; + + StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + NumSkip); + NumConsecutiveStores -= NumSkip; + continue; + } - // Remove merged stores for next iteration. - StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + NumElem); - NumConsecutiveStores -= NumElem; - } + // Check that we can merge these candidates without causing a cycle. 
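// --- Illustrative sketch, not part of the diff: the net effect of
// mergeStoresOfConstantsOrVecElts on a constant run. Four adjacent byte
// stores of known constants collapse into one 32-bit store of the packed
// (little-endian) constant. Plain C++, no LLVM API.
#include <cstdint>
#include <cstring>
static void storeConstantsSeparately(uint8_t *P) {
  P[0] = 0x11; P[1] = 0x22; P[2] = 0x33; P[3] = 0x44;
}
static void storeConstantsMerged(uint8_t *P) {
  const uint32_t Packed = 0x44332211u;     // same bytes, packed little-endian
  std::memcpy(P, &Packed, sizeof(Packed)); // one wide store
}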
+ if (!checkMergeStoreCandidatesForDependencies(StoreNodes, NumElem, + RootNode)) { + StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + NumElem); + NumConsecutiveStores -= NumElem; continue; } - // When extracting multiple vector elements, try to store them - // in one vector store rather than a sequence of scalar stores. - if (IsExtractVecSrc) { - // Loop on Consecutive Stores on success. - while (NumConsecutiveStores >= 2) { - LSBaseSDNode *FirstInChain = StoreNodes[0].MemNode; - unsigned FirstStoreAS = FirstInChain->getAddressSpace(); - unsigned FirstStoreAlign = FirstInChain->getAlignment(); - unsigned NumStoresToMerge = 1; - for (unsigned i = 0; i < NumConsecutiveStores; ++i) { - // Find a legal type for the vector store. - unsigned Elts = (i + 1) * NumMemElts; - EVT Ty = - EVT::getVectorVT(*DAG.getContext(), MemVT.getScalarType(), Elts); - bool IsFast; - - // Break early when size is too large to be legal. - if (Ty.getSizeInBits() > MaximumLegalStoreInBits) - break; + MadeChange |= mergeStoresOfConstantsOrVecElts( + StoreNodes, MemVT, NumElem, true, UseVector, LastIntegerTrunc); - if (TLI.isTypeLegal(Ty) && - TLI.canMergeStoresTo(FirstStoreAS, Ty, DAG) && - TLI.allowsMemoryAccess(Context, DL, Ty, - *FirstInChain->getMemOperand(), &IsFast) && - IsFast) - NumStoresToMerge = i + 1; - } + // Remove merged stores for next iteration. + StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + NumElem); + NumConsecutiveStores -= NumElem; + } + return MadeChange; +} - // Check if we found a legal integer type creating a meaningful - // merge. - if (NumStoresToMerge < 2) { - // We know that candidate stores are in order and of correct - // shape. While there is no mergeable sequence from the - // beginning one may start later in the sequence. The only - // reason a merge of size N could have failed where another of - // the same size would not have, is if the alignment has - // improved. Drop as many candidates as we can here. - unsigned NumSkip = 1; - while ( - (NumSkip < NumConsecutiveStores) && - (StoreNodes[NumSkip].MemNode->getAlignment() <= FirstStoreAlign)) - NumSkip++; - - StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + NumSkip); - NumConsecutiveStores -= NumSkip; - continue; - } +bool DAGCombiner::tryStoreMergeOfExtracts( + SmallVectorImpl<MemOpLink> &StoreNodes, unsigned NumConsecutiveStores, + EVT MemVT, SDNode *RootNode) { + LLVMContext &Context = *DAG.getContext(); + const DataLayout &DL = DAG.getDataLayout(); + unsigned NumMemElts = MemVT.isVector() ? MemVT.getVectorNumElements() : 1; + bool MadeChange = false; + + // Loop on Consecutive Stores on success. + while (NumConsecutiveStores >= 2) { + LSBaseSDNode *FirstInChain = StoreNodes[0].MemNode; + unsigned FirstStoreAS = FirstInChain->getAddressSpace(); + unsigned FirstStoreAlign = FirstInChain->getAlignment(); + unsigned NumStoresToMerge = 1; + for (unsigned i = 0; i < NumConsecutiveStores; ++i) { + // Find a legal type for the vector store. + unsigned Elts = (i + 1) * NumMemElts; + EVT Ty = EVT::getVectorVT(*DAG.getContext(), MemVT.getScalarType(), Elts); + bool IsFast = false; - // Check that we can merge these candidates without causing a cycle. - if (!checkMergeStoreCandidatesForDependencies( - StoreNodes, NumStoresToMerge, RootNode)) { - StoreNodes.erase(StoreNodes.begin(), - StoreNodes.begin() + NumStoresToMerge); - NumConsecutiveStores -= NumStoresToMerge; - continue; - } + // Break early when size is too large to be legal. 
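// --- Illustrative sketch, not part of the diff: what tryStoreMergeOfExtracts
// is aiming for. Storing each extracted lane of a vector to consecutive
// addresses is equivalent to one vector-width store. Plain C++ stand-in,
// no LLVM API.
#include <cstdint>
#include <cstring>
struct Vec4 { int32_t Lane[4]; };
static void storeLanesSeparately(int32_t *P, const Vec4 &V) {
  for (int I = 0; I != 4; ++I)
    P[I] = V.Lane[I];                      // extract_vector_elt + scalar store
}
static void storeLanesMerged(int32_t *P, const Vec4 &V) {
  std::memcpy(P, V.Lane, sizeof(V.Lane));  // one wide store of the whole vector
}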
+ if (Ty.getSizeInBits() > MaximumLegalStoreInBits) + break; - RV |= MergeStoresOfConstantsOrVecElts( - StoreNodes, MemVT, NumStoresToMerge, false, true, false); + if (TLI.isTypeLegal(Ty) && TLI.canMergeStoresTo(FirstStoreAS, Ty, DAG) && + TLI.allowsMemoryAccess(Context, DL, Ty, + *FirstInChain->getMemOperand(), &IsFast) && + IsFast) + NumStoresToMerge = i + 1; + } + + // Check if we found a legal integer type creating a meaningful + // merge. + if (NumStoresToMerge < 2) { + // We know that candidate stores are in order and of correct + // shape. While there is no mergeable sequence from the + // beginning one may start later in the sequence. The only + // reason a merge of size N could have failed where another of + // the same size would not have, is if the alignment has + // improved. Drop as many candidates as we can here. + unsigned NumSkip = 1; + while ((NumSkip < NumConsecutiveStores) && + (StoreNodes[NumSkip].MemNode->getAlignment() <= FirstStoreAlign)) + NumSkip++; + + StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + NumSkip); + NumConsecutiveStores -= NumSkip; + continue; + } - StoreNodes.erase(StoreNodes.begin(), - StoreNodes.begin() + NumStoresToMerge); - NumConsecutiveStores -= NumStoresToMerge; - } + // Check that we can merge these candidates without causing a cycle. + if (!checkMergeStoreCandidatesForDependencies(StoreNodes, NumStoresToMerge, + RootNode)) { + StoreNodes.erase(StoreNodes.begin(), + StoreNodes.begin() + NumStoresToMerge); + NumConsecutiveStores -= NumStoresToMerge; continue; } - // Below we handle the case of multiple consecutive stores that - // come from multiple consecutive loads. We merge them into a single - // wide load and a single wide store. + MadeChange |= mergeStoresOfConstantsOrVecElts( + StoreNodes, MemVT, NumStoresToMerge, false, true, false); + + StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + NumStoresToMerge); + NumConsecutiveStores -= NumStoresToMerge; + } + return MadeChange; +} + +bool DAGCombiner::tryStoreMergeOfLoads(SmallVectorImpl<MemOpLink> &StoreNodes, + unsigned NumConsecutiveStores, EVT MemVT, + SDNode *RootNode, bool AllowVectors, + bool IsNonTemporalStore, + bool IsNonTemporalLoad) { + LLVMContext &Context = *DAG.getContext(); + const DataLayout &DL = DAG.getDataLayout(); + int64_t ElementSizeBytes = MemVT.getStoreSize(); + unsigned NumMemElts = MemVT.isVector() ? MemVT.getVectorNumElements() : 1; + bool MadeChange = false; - // Look for load nodes which are used by the stored values. - SmallVector<MemOpLink, 8> LoadNodes; + int64_t StartAddress = StoreNodes[0].OffsetFromBase; - // Find acceptable loads. Loads need to have the same chain (token factor), - // must not be zext, volatile, indexed, and they must be consecutive. - BaseIndexOffset LdBasePtr; + // Look for load nodes which are used by the stored values. + SmallVector<MemOpLink, 8> LoadNodes; - for (unsigned i = 0; i < NumConsecutiveStores; ++i) { - StoreSDNode *St = cast<StoreSDNode>(StoreNodes[i].MemNode); - SDValue Val = peekThroughBitcasts(St->getValue()); - LoadSDNode *Ld = cast<LoadSDNode>(Val); - - BaseIndexOffset LdPtr = BaseIndexOffset::match(Ld, DAG); - // If this is not the first ptr that we check. - int64_t LdOffset = 0; - if (LdBasePtr.getBase().getNode()) { - // The base ptr must be the same. - if (!LdBasePtr.equalBaseIndex(LdPtr, DAG, LdOffset)) - break; - } else { - // Check that all other base pointers are the same as this one. - LdBasePtr = LdPtr; - } + // Find acceptable loads. 
Loads need to have the same chain (token factor), + // must not be zext, volatile, indexed, and they must be consecutive. + BaseIndexOffset LdBasePtr; + + for (unsigned i = 0; i < NumConsecutiveStores; ++i) { + StoreSDNode *St = cast<StoreSDNode>(StoreNodes[i].MemNode); + SDValue Val = peekThroughBitcasts(St->getValue()); + LoadSDNode *Ld = cast<LoadSDNode>(Val); - // We found a potential memory operand to merge. - LoadNodes.push_back(MemOpLink(Ld, LdOffset)); + BaseIndexOffset LdPtr = BaseIndexOffset::match(Ld, DAG); + // If this is not the first ptr that we check. + int64_t LdOffset = 0; + if (LdBasePtr.getBase().getNode()) { + // The base ptr must be the same. + if (!LdBasePtr.equalBaseIndex(LdPtr, DAG, LdOffset)) + break; + } else { + // Check that all other base pointers are the same as this one. + LdBasePtr = LdPtr; } - while (NumConsecutiveStores >= 2 && LoadNodes.size() >= 2) { + // We found a potential memory operand to merge. + LoadNodes.push_back(MemOpLink(Ld, LdOffset)); + } + + while (NumConsecutiveStores >= 2 && LoadNodes.size() >= 2) { + Align RequiredAlignment; + bool NeedRotate = false; + if (LoadNodes.size() == 2) { // If we have load/store pair instructions and we only have two values, // don't bother merging. - unsigned RequiredAlignment; - if (LoadNodes.size() == 2 && - TLI.hasPairedLoad(MemVT, RequiredAlignment) && - StoreNodes[0].MemNode->getAlignment() >= RequiredAlignment) { + if (TLI.hasPairedLoad(MemVT, RequiredAlignment) && + StoreNodes[0].MemNode->getAlign() >= RequiredAlignment) { StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + 2); LoadNodes.erase(LoadNodes.begin(), LoadNodes.begin() + 2); break; } - LSBaseSDNode *FirstInChain = StoreNodes[0].MemNode; - unsigned FirstStoreAS = FirstInChain->getAddressSpace(); - unsigned FirstStoreAlign = FirstInChain->getAlignment(); - LoadSDNode *FirstLoad = cast<LoadSDNode>(LoadNodes[0].MemNode); - unsigned FirstLoadAlign = FirstLoad->getAlignment(); - - // Scan the memory operations on the chain and find the first - // non-consecutive load memory address. These variables hold the index in - // the store node array. - - unsigned LastConsecutiveLoad = 1; - - // This variable refers to the size and not index in the array. - unsigned LastLegalVectorType = 1; - unsigned LastLegalIntegerType = 1; - bool isDereferenceable = true; - bool DoIntegerTruncate = false; - StartAddress = LoadNodes[0].OffsetFromBase; - SDValue FirstChain = FirstLoad->getChain(); - for (unsigned i = 1; i < LoadNodes.size(); ++i) { - // All loads must share the same chain. - if (LoadNodes[i].MemNode->getChain() != FirstChain) - break; + // If the loads are reversed, see if we can rotate the halves into place. + int64_t Offset0 = LoadNodes[0].OffsetFromBase; + int64_t Offset1 = LoadNodes[1].OffsetFromBase; + EVT PairVT = EVT::getIntegerVT(Context, ElementSizeBytes * 8 * 2); + if (Offset0 - Offset1 == ElementSizeBytes && + (hasOperation(ISD::ROTL, PairVT) || + hasOperation(ISD::ROTR, PairVT))) { + std::swap(LoadNodes[0], LoadNodes[1]); + NeedRotate = true; + } + } + LSBaseSDNode *FirstInChain = StoreNodes[0].MemNode; + unsigned FirstStoreAS = FirstInChain->getAddressSpace(); + unsigned FirstStoreAlign = FirstInChain->getAlignment(); + LoadSDNode *FirstLoad = cast<LoadSDNode>(LoadNodes[0].MemNode); - int64_t CurrAddress = LoadNodes[i].OffsetFromBase; - if (CurrAddress - StartAddress != (ElementSizeBytes * i)) - break; - LastConsecutiveLoad = i; + // Scan the memory operations on the chain and find the first + // non-consecutive load memory address. 
These variables hold the index in + // the store node array. + + unsigned LastConsecutiveLoad = 1; + + // This variable refers to the size and not index in the array. + unsigned LastLegalVectorType = 1; + unsigned LastLegalIntegerType = 1; + bool isDereferenceable = true; + bool DoIntegerTruncate = false; + StartAddress = LoadNodes[0].OffsetFromBase; + SDValue LoadChain = FirstLoad->getChain(); + for (unsigned i = 1; i < LoadNodes.size(); ++i) { + // All loads must share the same chain. + if (LoadNodes[i].MemNode->getChain() != LoadChain) + break; - if (isDereferenceable && !LoadNodes[i].MemNode->isDereferenceable()) - isDereferenceable = false; + int64_t CurrAddress = LoadNodes[i].OffsetFromBase; + if (CurrAddress - StartAddress != (ElementSizeBytes * i)) + break; + LastConsecutiveLoad = i; - // Find a legal type for the vector store. - unsigned Elts = (i + 1) * NumMemElts; - EVT StoreTy = EVT::getVectorVT(Context, MemVT.getScalarType(), Elts); + if (isDereferenceable && !LoadNodes[i].MemNode->isDereferenceable()) + isDereferenceable = false; - // Break early when size is too large to be legal. - if (StoreTy.getSizeInBits() > MaximumLegalStoreInBits) - break; + // Find a legal type for the vector store. + unsigned Elts = (i + 1) * NumMemElts; + EVT StoreTy = EVT::getVectorVT(Context, MemVT.getScalarType(), Elts); - bool IsFastSt, IsFastLd; - if (TLI.isTypeLegal(StoreTy) && - TLI.canMergeStoresTo(FirstStoreAS, StoreTy, DAG) && - TLI.allowsMemoryAccess(Context, DL, StoreTy, - *FirstInChain->getMemOperand(), &IsFastSt) && - IsFastSt && - TLI.allowsMemoryAccess(Context, DL, StoreTy, - *FirstLoad->getMemOperand(), &IsFastLd) && - IsFastLd) { - LastLegalVectorType = i + 1; - } + // Break early when size is too large to be legal. + if (StoreTy.getSizeInBits() > MaximumLegalStoreInBits) + break; - // Find a legal type for the integer store. - unsigned SizeInBits = (i + 1) * ElementSizeBytes * 8; - StoreTy = EVT::getIntegerVT(Context, SizeInBits); - if (TLI.isTypeLegal(StoreTy) && - TLI.canMergeStoresTo(FirstStoreAS, StoreTy, DAG) && + bool IsFastSt = false; + bool IsFastLd = false; + if (TLI.isTypeLegal(StoreTy) && + TLI.canMergeStoresTo(FirstStoreAS, StoreTy, DAG) && + TLI.allowsMemoryAccess(Context, DL, StoreTy, + *FirstInChain->getMemOperand(), &IsFastSt) && + IsFastSt && + TLI.allowsMemoryAccess(Context, DL, StoreTy, + *FirstLoad->getMemOperand(), &IsFastLd) && + IsFastLd) { + LastLegalVectorType = i + 1; + } + + // Find a legal type for the integer store. + unsigned SizeInBits = (i + 1) * ElementSizeBytes * 8; + StoreTy = EVT::getIntegerVT(Context, SizeInBits); + if (TLI.isTypeLegal(StoreTy) && + TLI.canMergeStoresTo(FirstStoreAS, StoreTy, DAG) && + TLI.allowsMemoryAccess(Context, DL, StoreTy, + *FirstInChain->getMemOperand(), &IsFastSt) && + IsFastSt && + TLI.allowsMemoryAccess(Context, DL, StoreTy, + *FirstLoad->getMemOperand(), &IsFastLd) && + IsFastLd) { + LastLegalIntegerType = i + 1; + DoIntegerTruncate = false; + // Or check whether a truncstore and extload is legal. 
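// --- Illustrative sketch, not part of the diff: the end result
// tryStoreMergeOfLoads is looking for. A run of consecutive load/store pairs
// becomes a single wide load feeding a single wide store. Plain C++, no
// LLVM API; memcpy stands in for the merged memory operations.
#include <cstdint>
#include <cstring>
static void copyElementWise(uint16_t *Dst, const uint16_t *Src) {
  Dst[0] = Src[0]; Dst[1] = Src[1]; Dst[2] = Src[2]; Dst[3] = Src[3];
}
static void copyMerged(uint16_t *Dst, const uint16_t *Src) {
  uint64_t Wide;
  std::memcpy(&Wide, Src, sizeof(Wide));   // one 64-bit load
  std::memcpy(Dst, &Wide, sizeof(Wide));   // one 64-bit store
}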
+ } else if (TLI.getTypeAction(Context, StoreTy) == + TargetLowering::TypePromoteInteger) { + EVT LegalizedStoredValTy = TLI.getTypeToTransformTo(Context, StoreTy); + if (TLI.isTruncStoreLegal(LegalizedStoredValTy, StoreTy) && + TLI.canMergeStoresTo(FirstStoreAS, LegalizedStoredValTy, DAG) && + TLI.isLoadExtLegal(ISD::ZEXTLOAD, LegalizedStoredValTy, StoreTy) && + TLI.isLoadExtLegal(ISD::SEXTLOAD, LegalizedStoredValTy, StoreTy) && + TLI.isLoadExtLegal(ISD::EXTLOAD, LegalizedStoredValTy, StoreTy) && TLI.allowsMemoryAccess(Context, DL, StoreTy, *FirstInChain->getMemOperand(), &IsFastSt) && IsFastSt && @@ -16125,149 +16658,225 @@ bool DAGCombiner::MergeConsecutiveStores(StoreSDNode *St) { *FirstLoad->getMemOperand(), &IsFastLd) && IsFastLd) { LastLegalIntegerType = i + 1; - DoIntegerTruncate = false; - // Or check whether a truncstore and extload is legal. - } else if (TLI.getTypeAction(Context, StoreTy) == - TargetLowering::TypePromoteInteger) { - EVT LegalizedStoredValTy = TLI.getTypeToTransformTo(Context, StoreTy); - if (TLI.isTruncStoreLegal(LegalizedStoredValTy, StoreTy) && - TLI.canMergeStoresTo(FirstStoreAS, LegalizedStoredValTy, DAG) && - TLI.isLoadExtLegal(ISD::ZEXTLOAD, LegalizedStoredValTy, - StoreTy) && - TLI.isLoadExtLegal(ISD::SEXTLOAD, LegalizedStoredValTy, - StoreTy) && - TLI.isLoadExtLegal(ISD::EXTLOAD, LegalizedStoredValTy, StoreTy) && - TLI.allowsMemoryAccess(Context, DL, StoreTy, - *FirstInChain->getMemOperand(), - &IsFastSt) && - IsFastSt && - TLI.allowsMemoryAccess(Context, DL, StoreTy, - *FirstLoad->getMemOperand(), &IsFastLd) && - IsFastLd) { - LastLegalIntegerType = i + 1; - DoIntegerTruncate = true; - } + DoIntegerTruncate = true; } } + } - // Only use vector types if the vector type is larger than the integer - // type. If they are the same, use integers. - bool UseVectorTy = - LastLegalVectorType > LastLegalIntegerType && !NoVectors; - unsigned LastLegalType = - std::max(LastLegalVectorType, LastLegalIntegerType); - - // We add +1 here because the LastXXX variables refer to location while - // the NumElem refers to array/index size. - unsigned NumElem = - std::min(NumConsecutiveStores, LastConsecutiveLoad + 1); - NumElem = std::min(LastLegalType, NumElem); - - if (NumElem < 2) { - // We know that candidate stores are in order and of correct - // shape. While there is no mergeable sequence from the - // beginning one may start later in the sequence. The only - // reason a merge of size N could have failed where another of - // the same size would not have is if the alignment or either - // the load or store has improved. Drop as many candidates as we - // can here. - unsigned NumSkip = 1; - while ((NumSkip < LoadNodes.size()) && - (LoadNodes[NumSkip].MemNode->getAlignment() <= FirstLoadAlign) && - (StoreNodes[NumSkip].MemNode->getAlignment() <= FirstStoreAlign)) - NumSkip++; - StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + NumSkip); - LoadNodes.erase(LoadNodes.begin(), LoadNodes.begin() + NumSkip); - NumConsecutiveStores -= NumSkip; - continue; - } + // Only use vector types if the vector type is larger than the integer + // type. If they are the same, use integers. + bool UseVectorTy = + LastLegalVectorType > LastLegalIntegerType && AllowVectors; + unsigned LastLegalType = + std::max(LastLegalVectorType, LastLegalIntegerType); + + // We add +1 here because the LastXXX variables refer to location while + // the NumElem refers to array/index size. 
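// --- Illustrative sketch, not part of the diff: the NeedRotate case set up
// earlier in this function. When the two source loads sit in the opposite
// order from the stores, one wide load followed by a rotate by half the
// width puts the halves where the merged store needs them. Plain C++ model
// for a pair of 16-bit elements on a little-endian target, no LLVM API.
#include <cstdint>
static uint32_t loadPairRotated(const uint16_t *P) {
  uint32_t Wide = static_cast<uint32_t>(P[0]) |
                  (static_cast<uint32_t>(P[1]) << 16); // wide load of {P[0], P[1]}
  return (Wide << 16) | (Wide >> 16);                  // ROTL by half width -> {P[1], P[0]}
}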
+ unsigned NumElem = std::min(NumConsecutiveStores, LastConsecutiveLoad + 1); + NumElem = std::min(LastLegalType, NumElem); + unsigned FirstLoadAlign = FirstLoad->getAlignment(); + + if (NumElem < 2) { + // We know that candidate stores are in order and of correct + // shape. While there is no mergeable sequence from the + // beginning one may start later in the sequence. The only + // reason a merge of size N could have failed where another of + // the same size would not have is if the alignment or either + // the load or store has improved. Drop as many candidates as we + // can here. + unsigned NumSkip = 1; + while ((NumSkip < LoadNodes.size()) && + (LoadNodes[NumSkip].MemNode->getAlignment() <= FirstLoadAlign) && + (StoreNodes[NumSkip].MemNode->getAlignment() <= FirstStoreAlign)) + NumSkip++; + StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + NumSkip); + LoadNodes.erase(LoadNodes.begin(), LoadNodes.begin() + NumSkip); + NumConsecutiveStores -= NumSkip; + continue; + } - // Check that we can merge these candidates without causing a cycle. - if (!checkMergeStoreCandidatesForDependencies(StoreNodes, NumElem, - RootNode)) { - StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + NumElem); - LoadNodes.erase(LoadNodes.begin(), LoadNodes.begin() + NumElem); - NumConsecutiveStores -= NumElem; - continue; - } + // Check that we can merge these candidates without causing a cycle. + if (!checkMergeStoreCandidatesForDependencies(StoreNodes, NumElem, + RootNode)) { + StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + NumElem); + LoadNodes.erase(LoadNodes.begin(), LoadNodes.begin() + NumElem); + NumConsecutiveStores -= NumElem; + continue; + } - // Find if it is better to use vectors or integers to load and store - // to memory. - EVT JointMemOpVT; - if (UseVectorTy) { - // Find a legal type for the vector store. - unsigned Elts = NumElem * NumMemElts; - JointMemOpVT = EVT::getVectorVT(Context, MemVT.getScalarType(), Elts); - } else { - unsigned SizeInBits = NumElem * ElementSizeBytes * 8; - JointMemOpVT = EVT::getIntegerVT(Context, SizeInBits); + // Find if it is better to use vectors or integers to load and store + // to memory. + EVT JointMemOpVT; + if (UseVectorTy) { + // Find a legal type for the vector store. + unsigned Elts = NumElem * NumMemElts; + JointMemOpVT = EVT::getVectorVT(Context, MemVT.getScalarType(), Elts); + } else { + unsigned SizeInBits = NumElem * ElementSizeBytes * 8; + JointMemOpVT = EVT::getIntegerVT(Context, SizeInBits); + } + + SDLoc LoadDL(LoadNodes[0].MemNode); + SDLoc StoreDL(StoreNodes[0].MemNode); + + // The merged loads are required to have the same incoming chain, so + // using the first's chain is acceptable. + + SDValue NewStoreChain = getMergeStoreChains(StoreNodes, NumElem); + AddToWorklist(NewStoreChain.getNode()); + + MachineMemOperand::Flags LdMMOFlags = + isDereferenceable ? MachineMemOperand::MODereferenceable + : MachineMemOperand::MONone; + if (IsNonTemporalLoad) + LdMMOFlags |= MachineMemOperand::MONonTemporal; + + MachineMemOperand::Flags StMMOFlags = IsNonTemporalStore + ? 
MachineMemOperand::MONonTemporal + : MachineMemOperand::MONone; + + SDValue NewLoad, NewStore; + if (UseVectorTy || !DoIntegerTruncate) { + NewLoad = DAG.getLoad( + JointMemOpVT, LoadDL, FirstLoad->getChain(), FirstLoad->getBasePtr(), + FirstLoad->getPointerInfo(), FirstLoadAlign, LdMMOFlags); + SDValue StoreOp = NewLoad; + if (NeedRotate) { + unsigned LoadWidth = ElementSizeBytes * 8 * 2; + assert(JointMemOpVT == EVT::getIntegerVT(Context, LoadWidth) && + "Unexpected type for rotate-able load pair"); + SDValue RotAmt = + DAG.getShiftAmountConstant(LoadWidth / 2, JointMemOpVT, LoadDL); + // Target can convert to the identical ROTR if it does not have ROTL. + StoreOp = DAG.getNode(ISD::ROTL, LoadDL, JointMemOpVT, NewLoad, RotAmt); } + NewStore = DAG.getStore( + NewStoreChain, StoreDL, StoreOp, FirstInChain->getBasePtr(), + FirstInChain->getPointerInfo(), FirstStoreAlign, StMMOFlags); + } else { // This must be the truncstore/extload case + EVT ExtendedTy = + TLI.getTypeToTransformTo(*DAG.getContext(), JointMemOpVT); + NewLoad = DAG.getExtLoad(ISD::EXTLOAD, LoadDL, ExtendedTy, + FirstLoad->getChain(), FirstLoad->getBasePtr(), + FirstLoad->getPointerInfo(), JointMemOpVT, + FirstLoadAlign, LdMMOFlags); + NewStore = DAG.getTruncStore(NewStoreChain, StoreDL, NewLoad, + FirstInChain->getBasePtr(), + FirstInChain->getPointerInfo(), JointMemOpVT, + FirstInChain->getAlignment(), + FirstInChain->getMemOperand()->getFlags()); + } + + // Transfer chain users from old loads to the new load. + for (unsigned i = 0; i < NumElem; ++i) { + LoadSDNode *Ld = cast<LoadSDNode>(LoadNodes[i].MemNode); + DAG.ReplaceAllUsesOfValueWith(SDValue(Ld, 1), + SDValue(NewLoad.getNode(), 1)); + } + + // Replace all stores with the new store. Recursively remove corresponding + // values if they are no longer used. + for (unsigned i = 0; i < NumElem; ++i) { + SDValue Val = StoreNodes[i].MemNode->getOperand(1); + CombineTo(StoreNodes[i].MemNode, NewStore); + if (Val.getNode()->use_empty()) + recursivelyDeleteUnusedNodes(Val.getNode()); + } + + MadeChange = true; + StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + NumElem); + LoadNodes.erase(LoadNodes.begin(), LoadNodes.begin() + NumElem); + NumConsecutiveStores -= NumElem; + } + return MadeChange; +} + +bool DAGCombiner::mergeConsecutiveStores(StoreSDNode *St) { + if (OptLevel == CodeGenOpt::None || !EnableStoreMerging) + return false; - SDLoc LoadDL(LoadNodes[0].MemNode); - SDLoc StoreDL(StoreNodes[0].MemNode); - - // The merged loads are required to have the same incoming chain, so - // using the first's chain is acceptable. - - SDValue NewStoreChain = getMergeStoreChains(StoreNodes, NumElem); - AddToWorklist(NewStoreChain.getNode()); - - MachineMemOperand::Flags LdMMOFlags = - isDereferenceable ? MachineMemOperand::MODereferenceable - : MachineMemOperand::MONone; - if (IsNonTemporalLoad) - LdMMOFlags |= MachineMemOperand::MONonTemporal; - - MachineMemOperand::Flags StMMOFlags = - IsNonTemporalStore ? 
MachineMemOperand::MONonTemporal - : MachineMemOperand::MONone; - - SDValue NewLoad, NewStore; - if (UseVectorTy || !DoIntegerTruncate) { - NewLoad = - DAG.getLoad(JointMemOpVT, LoadDL, FirstLoad->getChain(), - FirstLoad->getBasePtr(), FirstLoad->getPointerInfo(), - FirstLoadAlign, LdMMOFlags); - NewStore = DAG.getStore( - NewStoreChain, StoreDL, NewLoad, FirstInChain->getBasePtr(), - FirstInChain->getPointerInfo(), FirstStoreAlign, StMMOFlags); - } else { // This must be the truncstore/extload case - EVT ExtendedTy = - TLI.getTypeToTransformTo(*DAG.getContext(), JointMemOpVT); - NewLoad = DAG.getExtLoad(ISD::EXTLOAD, LoadDL, ExtendedTy, - FirstLoad->getChain(), FirstLoad->getBasePtr(), - FirstLoad->getPointerInfo(), JointMemOpVT, - FirstLoadAlign, LdMMOFlags); - NewStore = DAG.getTruncStore(NewStoreChain, StoreDL, NewLoad, - FirstInChain->getBasePtr(), - FirstInChain->getPointerInfo(), - JointMemOpVT, FirstInChain->getAlignment(), - FirstInChain->getMemOperand()->getFlags()); - } + // TODO: Extend this function to merge stores of scalable vectors. + // (i.e. two <vscale x 8 x i8> stores can be merged to one <vscale x 16 x i8> + // store since we know <vscale x 16 x i8> is exactly twice as large as + // <vscale x 8 x i8>). Until then, bail out for scalable vectors. + EVT MemVT = St->getMemoryVT(); + if (MemVT.isScalableVector()) + return false; + if (!MemVT.isSimple() || MemVT.getSizeInBits() * 2 > MaximumLegalStoreInBits) + return false; - // Transfer chain users from old loads to the new load. - for (unsigned i = 0; i < NumElem; ++i) { - LoadSDNode *Ld = cast<LoadSDNode>(LoadNodes[i].MemNode); - DAG.ReplaceAllUsesOfValueWith(SDValue(Ld, 1), - SDValue(NewLoad.getNode(), 1)); - } + // This function cannot currently deal with non-byte-sized memory sizes. + int64_t ElementSizeBytes = MemVT.getStoreSize(); + if (ElementSizeBytes * 8 != (int64_t)MemVT.getSizeInBits()) + return false; - // Replace the all stores with the new store. Recursively remove - // corresponding value if its no longer used. - for (unsigned i = 0; i < NumElem; ++i) { - SDValue Val = StoreNodes[i].MemNode->getOperand(1); - CombineTo(StoreNodes[i].MemNode, NewStore); - if (Val.getNode()->use_empty()) - recursivelyDeleteUnusedNodes(Val.getNode()); - } + // Do not bother looking at stored values that are not constants, loads, or + // extracted vector elements. + SDValue StoredVal = peekThroughBitcasts(St->getValue()); + const StoreSource StoreSrc = getStoreSource(StoredVal); + if (StoreSrc == StoreSource::Unknown) + return false; - RV = true; - StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + NumElem); - LoadNodes.erase(LoadNodes.begin(), LoadNodes.begin() + NumElem); - NumConsecutiveStores -= NumElem; + SmallVector<MemOpLink, 8> StoreNodes; + SDNode *RootNode; + // Find potential store merge candidates by searching through chain sub-DAG + getStoreMergeCandidates(St, StoreNodes, RootNode); + + // Check if there is anything to merge. + if (StoreNodes.size() < 2) + return false; + + // Sort the memory operands according to their distance from the + // base pointer. + llvm::sort(StoreNodes, [](MemOpLink LHS, MemOpLink RHS) { + return LHS.OffsetFromBase < RHS.OffsetFromBase; + }); + + bool AllowVectors = !DAG.getMachineFunction().getFunction().hasFnAttribute( + Attribute::NoImplicitFloat); + bool IsNonTemporalStore = St->isNonTemporal(); + bool IsNonTemporalLoad = StoreSrc == StoreSource::Load && + cast<LoadSDNode>(StoredVal)->isNonTemporal(); + + // Store Merge attempts to merge the lowest stores. 
This generally + // works out as if successful, as the remaining stores are checked + // after the first collection of stores is merged. However, in the + // case that a non-mergeable store is found first, e.g., {p[-2], + // p[0], p[1], p[2], p[3]}, we would fail and miss the subsequent + // mergeable cases. To prevent this, we prune such stores from the + // front of StoreNodes here. + bool MadeChange = false; + while (StoreNodes.size() > 1) { + unsigned NumConsecutiveStores = + getConsecutiveStores(StoreNodes, ElementSizeBytes); + // There are no more stores in the list to examine. + if (NumConsecutiveStores == 0) + return MadeChange; + + // We have at least 2 consecutive stores. Try to merge them. + assert(NumConsecutiveStores >= 2 && "Expected at least 2 stores"); + switch (StoreSrc) { + case StoreSource::Constant: + MadeChange |= tryStoreMergeOfConstants(StoreNodes, NumConsecutiveStores, + MemVT, RootNode, AllowVectors); + break; + + case StoreSource::Extract: + MadeChange |= tryStoreMergeOfExtracts(StoreNodes, NumConsecutiveStores, + MemVT, RootNode); + break; + + case StoreSource::Load: + MadeChange |= tryStoreMergeOfLoads(StoreNodes, NumConsecutiveStores, + MemVT, RootNode, AllowVectors, + IsNonTemporalStore, IsNonTemporalLoad); + break; + + default: + llvm_unreachable("Unhandled store source type"); } } - return RV; + return MadeChange; } SDValue DAGCombiner::replaceStoreChain(StoreSDNode *ST, SDValue BetterChain) { @@ -16408,11 +17017,12 @@ SDValue DAGCombiner::visitSTORE(SDNode *N) { // Try to infer better alignment information than the store already has. if (OptLevel != CodeGenOpt::None && ST->isUnindexed() && !ST->isAtomic()) { - if (unsigned Align = DAG.InferPtrAlignment(Ptr)) { - if (Align > ST->getAlignment() && ST->getSrcValueOffset() % Align == 0) { + if (MaybeAlign Alignment = DAG.InferPtrAlign(Ptr)) { + if (*Alignment > ST->getAlign() && + isAligned(*Alignment, ST->getSrcValueOffset())) { SDValue NewStore = DAG.getTruncStore(Chain, SDLoc(N), Value, Ptr, ST->getPointerInfo(), - ST->getMemoryVT(), Align, + ST->getMemoryVT(), *Alignment, ST->getMemOperand()->getFlags(), ST->getAAInfo()); // NewStore will always be N as we are only refining the alignment assert(NewStore.getNode() == N); @@ -16497,7 +17107,10 @@ SDValue DAGCombiner::visitSTORE(SDNode *N) { } if (OptLevel != CodeGenOpt::None && ST1->hasOneUse() && - !ST1->getBasePtr().isUndef()) { + !ST1->getBasePtr().isUndef() && + // BaseIndexOffset and the code below requires knowing the size + // of a vector, so bail out if MemoryVT is scalable. + !ST1->getMemoryVT().isScalableVector()) { const BaseIndexOffset STBase = BaseIndexOffset::match(ST, DAG); const BaseIndexOffset ChainBase = BaseIndexOffset::match(ST1, DAG); unsigned STBitSize = ST->getMemoryVT().getSizeInBits(); @@ -16510,33 +17123,6 @@ SDValue DAGCombiner::visitSTORE(SDNode *N) { CombineTo(ST1, ST1->getChain()); return SDValue(); } - - // If ST stores to a subset of preceding store's write set, we may be - // able to fold ST's value into the preceding stored value. As we know - // the other uses of ST1's chain are unconcerned with ST, this folding - // will not affect those nodes. - int64_t BitOffset; - if (ChainBase.contains(DAG, ChainBitSize, STBase, STBitSize, - BitOffset)) { - SDValue ChainValue = ST1->getValue(); - if (auto *C1 = dyn_cast<ConstantSDNode>(ChainValue)) { - if (auto *C = dyn_cast<ConstantSDNode>(Value)) { - APInt Val = C1->getAPIntValue(); - APInt InsertVal = C->getAPIntValue().zextOrTrunc(STBitSize); - // FIXME: Handle Big-endian mode. 
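// --- Illustrative sketch, not part of the diff: what the removed "ST subset
// of ST1" fold did. A narrow constant store that lands inside the byte range
// of a preceding wide constant store can be folded by patching the wide
// constant (little-endian only, hence the FIXME above). Plain C++, no
// LLVM API.
#include <cstdint>
#include <cstring>
static void storeThenPatchByte(uint8_t *P) {
  const uint32_t Wide = 0x11223344u;
  std::memcpy(P, &Wide, 4);              // ST1: wide constant store
  P[1] = 0xAA;                           // ST: narrow store inside ST1's range
}
static void storeFoldedConstant(uint8_t *P) {
  const uint32_t Patched = 0x1122AA44u;  // 0xAA inserted at bit offset 8
  std::memcpy(P, &Patched, 4);           // single wide store, ST becomes dead
}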
- if (!DAG.getDataLayout().isBigEndian()) { - Val.insertBits(InsertVal, BitOffset); - SDValue NewSDVal = - DAG.getConstant(Val, SDLoc(C), ChainValue.getValueType(), - C1->isTargetOpcode(), C1->isOpaque()); - SDNode *NewST1 = DAG.UpdateNodeOperands( - ST1, ST1->getChain(), NewSDVal, ST1->getOperand(2), - ST1->getOperand(3)); - return CombineTo(ST, SDValue(NewST1, 0)); - } - } - } - } // End ST subset of ST1 case. } } } @@ -16559,7 +17145,7 @@ SDValue DAGCombiner::visitSTORE(SDNode *N) { // There can be multiple store sequences on the same chain. // Keep trying to merge store sequences until we are unable to do so // or until we merge the last store on the chain. - bool Changed = MergeConsecutiveStores(ST); + bool Changed = mergeConsecutiveStores(ST); if (!Changed) break; // Return N as merge only uses CombineTo and no worklist clean // up is necessary. @@ -16835,6 +17421,10 @@ SDValue DAGCombiner::combineInsertEltToShuffle(SDNode *N, unsigned InsIndex) { EVT SubVecVT = SubVec.getValueType(); EVT VT = DestVec.getValueType(); unsigned NumSrcElts = SubVecVT.getVectorNumElements(); + // If the source only has a single vector element, the cost of creating adding + // it to a vector is likely to exceed the cost of a insert_vector_elt. + if (NumSrcElts == 1) + return SDValue(); unsigned ExtendRatio = VT.getSizeInBits() / SubVecVT.getSizeInBits(); unsigned NumMaskVals = ExtendRatio * NumSrcElts; @@ -16880,12 +17470,12 @@ SDValue DAGCombiner::visitINSERT_VECTOR_ELT(SDNode *N) { SDLoc DL(N); EVT VT = InVec.getValueType(); - unsigned NumElts = VT.getVectorNumElements(); + auto *IndexC = dyn_cast<ConstantSDNode>(EltNo); // Insert into out-of-bounds element is undefined. - if (auto *IndexC = dyn_cast<ConstantSDNode>(EltNo)) - if (IndexC->getZExtValue() >= VT.getVectorNumElements()) - return DAG.getUNDEF(VT); + if (IndexC && VT.isFixedLengthVector() && + IndexC->getZExtValue() >= VT.getVectorNumElements()) + return DAG.getUNDEF(VT); // Remove redundant insertions: // (insert_vector_elt x (extract_vector_elt x idx) idx) -> x @@ -16893,17 +17483,25 @@ SDValue DAGCombiner::visitINSERT_VECTOR_ELT(SDNode *N) { InVec == InVal.getOperand(0) && EltNo == InVal.getOperand(1)) return InVec; - auto *IndexC = dyn_cast<ConstantSDNode>(EltNo); if (!IndexC) { // If this is variable insert to undef vector, it might be better to splat: // inselt undef, InVal, EltNo --> build_vector < InVal, InVal, ... > if (InVec.isUndef() && TLI.shouldSplatInsEltVarIndex(VT)) { - SmallVector<SDValue, 8> Ops(NumElts, InVal); - return DAG.getBuildVector(VT, DL, Ops); + if (VT.isScalableVector()) + return DAG.getSplatVector(VT, DL, InVal); + else { + SmallVector<SDValue, 8> Ops(VT.getVectorNumElements(), InVal); + return DAG.getBuildVector(VT, DL, Ops); + } } return SDValue(); } + if (VT.isScalableVector()) + return SDValue(); + + unsigned NumElts = VT.getVectorNumElements(); + // We must know which element is being inserted for folds below here. 
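// --- Illustrative sketch, not part of the diff: why the
// "inselt undef, InVal, EltNo --> splat" fold above is sound. If every other
// lane is undefined, filling all lanes with InVal is a legal refinement and
// avoids a variable-index insert. Plain C++ model over a 4-lane vector,
// no LLVM API.
#include <array>
#include <cstddef>
static std::array<int, 4> insertIntoUndef(int InVal, std::size_t /*EltNo*/) {
  std::array<int, 4> R;
  R.fill(InVal);   // splat: the requested lane holds InVal; the rest may be
                   // anything because the source lanes were undef
  return R;
}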
unsigned Elt = IndexC->getZExtValue(); if (SDValue Shuf = combineInsertEltToShuffle(N, Elt)) @@ -16968,11 +17566,12 @@ SDValue DAGCombiner::scalarizeExtractedVectorLoad(SDNode *EVE, EVT InVecVT, EVT ResultVT = EVE->getValueType(0); EVT VecEltVT = InVecVT.getVectorElementType(); - unsigned Align = OriginalLoad->getAlignment(); - unsigned NewAlign = DAG.getDataLayout().getABITypeAlignment( + Align Alignment = OriginalLoad->getAlign(); + Align NewAlign = DAG.getDataLayout().getABITypeAlign( VecEltVT.getTypeForEVT(*DAG.getContext())); - if (NewAlign > Align || !TLI.isOperationLegalOrCustom(ISD::LOAD, VecEltVT)) + if (NewAlign > Alignment || + !TLI.isOperationLegalOrCustom(ISD::LOAD, VecEltVT)) return SDValue(); ISD::LoadExtType ExtTy = ResultVT.bitsGT(VecEltVT) ? @@ -16980,7 +17579,7 @@ SDValue DAGCombiner::scalarizeExtractedVectorLoad(SDNode *EVE, EVT InVecVT, if (!TLI.shouldReduceLoadWidth(OriginalLoad, ExtTy, VecEltVT)) return SDValue(); - Align = NewAlign; + Alignment = NewAlign; SDValue NewPtr = OriginalLoad->getBasePtr(); SDValue Offset; @@ -17020,13 +17619,13 @@ SDValue DAGCombiner::scalarizeExtractedVectorLoad(SDNode *EVE, EVT InVecVT, : ISD::EXTLOAD; Load = DAG.getExtLoad(ExtType, SDLoc(EVE), ResultVT, OriginalLoad->getChain(), NewPtr, MPI, VecEltVT, - Align, OriginalLoad->getMemOperand()->getFlags(), + Alignment, OriginalLoad->getMemOperand()->getFlags(), OriginalLoad->getAAInfo()); Chain = Load.getValue(1); } else { - Load = DAG.getLoad(VecEltVT, SDLoc(EVE), OriginalLoad->getChain(), NewPtr, - MPI, Align, OriginalLoad->getMemOperand()->getFlags(), - OriginalLoad->getAAInfo()); + Load = DAG.getLoad( + VecEltVT, SDLoc(EVE), OriginalLoad->getChain(), NewPtr, MPI, Alignment, + OriginalLoad->getMemOperand()->getFlags(), OriginalLoad->getAAInfo()); Chain = Load.getValue(1); if (ResultVT.bitsLT(VecEltVT)) Load = DAG.getNode(ISD::TRUNCATE, SDLoc(EVE), ResultVT, Load); @@ -17102,6 +17701,10 @@ SDValue DAGCombiner::visitEXTRACT_VECTOR_ELT(SDNode *N) { // (vextract (scalar_to_vector val, 0) -> val if (VecOp.getOpcode() == ISD::SCALAR_TO_VECTOR) { + // Only 0'th element of SCALAR_TO_VECTOR is defined. + if (DAG.isKnownNeverZero(Index)) + return DAG.getUNDEF(ScalarVT); + // Check if the result type doesn't match the inserted element type. A // SCALAR_TO_VECTOR may truncate the inserted element and the // EXTRACT_VECTOR_ELT may widen the extracted vector. @@ -17115,15 +17718,21 @@ SDValue DAGCombiner::visitEXTRACT_VECTOR_ELT(SDNode *N) { // extract_vector_elt of out-of-bounds element -> UNDEF auto *IndexC = dyn_cast<ConstantSDNode>(Index); - unsigned NumElts = VecVT.getVectorNumElements(); - if (IndexC && IndexC->getAPIntValue().uge(NumElts)) + if (IndexC && VecVT.isFixedLengthVector() && + IndexC->getAPIntValue().uge(VecVT.getVectorNumElements())) return DAG.getUNDEF(ScalarVT); // extract_vector_elt (build_vector x, y), 1 -> y - if (IndexC && VecOp.getOpcode() == ISD::BUILD_VECTOR && + if (((IndexC && VecOp.getOpcode() == ISD::BUILD_VECTOR) || + VecOp.getOpcode() == ISD::SPLAT_VECTOR) && TLI.isTypeLegal(VecVT) && (VecOp.hasOneUse() || TLI.aggressivelyPreferBuildVectorSources(VecVT))) { - SDValue Elt = VecOp.getOperand(IndexC->getZExtValue()); + assert((VecOp.getOpcode() != ISD::BUILD_VECTOR || + VecVT.isFixedLengthVector()) && + "BUILD_VECTOR used for scalable vectors"); + unsigned IndexVal = + VecOp.getOpcode() == ISD::BUILD_VECTOR ? 
IndexC->getZExtValue() : 0; + SDValue Elt = VecOp.getOperand(IndexVal); EVT InEltVT = Elt.getValueType(); // Sometimes build_vector's scalar input types do not match result type. @@ -17134,6 +17743,15 @@ SDValue DAGCombiner::visitEXTRACT_VECTOR_ELT(SDNode *N) { // converts. } + if (VecVT.isScalableVector()) + return SDValue(); + + // All the code from this point onwards assumes fixed width vectors, but it's + // possible that some of the combinations could be made to work for scalable + // vectors too. + unsigned NumElts = VecVT.getVectorNumElements(); + unsigned VecEltBitWidth = VecVT.getScalarSizeInBits(); + // TODO: These transforms should not require the 'hasOneUse' restriction, but // there are regressions on multiple targets without it. We can end up with a // mess of scalar and vector code if we reduce only part of the DAG to scalar. @@ -17157,7 +17775,6 @@ SDValue DAGCombiner::visitEXTRACT_VECTOR_ELT(SDNode *N) { "Extract element and scalar to vector can't change element type " "from FP to integer."); unsigned XBitWidth = X.getValueSizeInBits(); - unsigned VecEltBitWidth = VecVT.getScalarSizeInBits(); BCTruncElt = IsLE ? 0 : XBitWidth / VecEltBitWidth - 1; // An extract element return value type can be wider than its vector @@ -17215,9 +17832,8 @@ SDValue DAGCombiner::visitEXTRACT_VECTOR_ELT(SDNode *N) { // FIXME: Should really be just isOperationLegalOrCustom. TLI.isOperationLegal(ISD::EXTRACT_VECTOR_ELT, VecVT) || TLI.isOperationExpand(ISD::VECTOR_SHUFFLE, VecVT)) { - EVT IndexTy = TLI.getVectorIdxTy(DAG.getDataLayout()); return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, ScalarVT, SVInVec, - DAG.getConstant(OrigElt, DL, IndexTy)); + DAG.getVectorIdxConstant(OrigElt, DL)); } } @@ -17241,6 +17857,14 @@ SDValue DAGCombiner::visitEXTRACT_VECTOR_ELT(SDNode *N) { AddToWorklist(N); return SDValue(N, 0); } + APInt DemandedBits = APInt::getAllOnesValue(VecEltBitWidth); + if (SimplifyDemandedBits(VecOp, DemandedBits, DemandedElts, true)) { + // We simplified the vector operand of this extract element. If this + // extract is not dead, visit it again so it is folded properly. + if (N->getOpcode() != ISD::DELETED_NODE) + AddToWorklist(N); + return SDValue(N, 0); + } } // Everything under here is trying to match an extract of a loaded value. @@ -17326,6 +17950,30 @@ SDValue DAGCombiner::visitEXTRACT_VECTOR_ELT(SDNode *N) { Elt = (Idx < (int)NumElts) ? 
Idx : Idx - (int)NumElts; Index = DAG.getConstant(Elt, DL, Index.getValueType()); } + } else if (VecOp.getOpcode() == ISD::CONCAT_VECTORS && !BCNumEltsChanged && + VecVT.getVectorElementType() == ScalarVT && + (!LegalTypes || + TLI.isTypeLegal( + VecOp.getOperand(0).getValueType().getVectorElementType()))) { + // extract_vector_elt (concat_vectors v2i16:a, v2i16:b), 0 + // -> extract_vector_elt a, 0 + // extract_vector_elt (concat_vectors v2i16:a, v2i16:b), 1 + // -> extract_vector_elt a, 1 + // extract_vector_elt (concat_vectors v2i16:a, v2i16:b), 2 + // -> extract_vector_elt b, 0 + // extract_vector_elt (concat_vectors v2i16:a, v2i16:b), 3 + // -> extract_vector_elt b, 1 + SDLoc SL(N); + EVT ConcatVT = VecOp.getOperand(0).getValueType(); + unsigned ConcatNumElts = ConcatVT.getVectorNumElements(); + SDValue NewIdx = DAG.getConstant(Elt % ConcatNumElts, SL, + Index.getValueType()); + + SDValue ConcatOp = VecOp.getOperand(Elt / ConcatNumElts); + SDValue Elt = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, SL, + ConcatVT.getVectorElementType(), + ConcatOp, NewIdx); + return DAG.getNode(ISD::BITCAST, SL, ScalarVT, Elt); } // Make sure we found a non-volatile load and the extractelement is @@ -17407,6 +18055,11 @@ SDValue DAGCombiner::reduceBuildVecExtToExtBuildVec(SDNode *N) { if (!ValidTypes) return SDValue(); + // If we already have a splat buildvector, then don't fold it if it means + // introducing zeros. + if (!AllAnyExt && DAG.isSplatValue(SDValue(N, 0), /*AllowUndefs*/ true)) + return SDValue(); + bool isLE = DAG.getDataLayout().isLittleEndian(); unsigned ElemRatio = OutScalarTy.getSizeInBits()/SourceType.getSizeInBits(); assert(ElemRatio > 1 && "Invalid element size ratio"); @@ -17453,12 +18106,89 @@ SDValue DAGCombiner::reduceBuildVecExtToExtBuildVec(SDNode *N) { return DAG.getBitcast(VT, BV); } +// Simplify (build_vec (trunc $1) +// (trunc (srl $1 half-width)) +// (trunc (srl $1 (2 * half-width))) …) +// to (bitcast $1) +SDValue DAGCombiner::reduceBuildVecTruncToBitCast(SDNode *N) { + assert(N->getOpcode() == ISD::BUILD_VECTOR && "Expected build vector"); + + // Only for little endian + if (!DAG.getDataLayout().isLittleEndian()) + return SDValue(); + + SDLoc DL(N); + EVT VT = N->getValueType(0); + EVT OutScalarTy = VT.getScalarType(); + uint64_t ScalarTypeBitsize = OutScalarTy.getSizeInBits(); + + // Only for power of two types to be sure that bitcast works well + if (!isPowerOf2_64(ScalarTypeBitsize)) + return SDValue(); + + unsigned NumInScalars = N->getNumOperands(); + + // Look through bitcasts + auto PeekThroughBitcast = [](SDValue Op) { + if (Op.getOpcode() == ISD::BITCAST) + return Op.getOperand(0); + return Op; + }; + + // The source value where all the parts are extracted. + SDValue Src; + for (unsigned i = 0; i != NumInScalars; ++i) { + SDValue In = PeekThroughBitcast(N->getOperand(i)); + // Ignore undef inputs. + if (In.isUndef()) continue; + + if (In.getOpcode() != ISD::TRUNCATE) + return SDValue(); + + In = PeekThroughBitcast(In.getOperand(0)); + + if (In.getOpcode() != ISD::SRL) { + // For now only build_vec without shuffling, handle shifts here in the + // future. 
+ if (i != 0) + return SDValue(); + + Src = In; + } else { + // In is SRL + SDValue part = PeekThroughBitcast(In.getOperand(0)); + + if (!Src) { + Src = part; + } else if (Src != part) { + // Vector parts do not stem from the same variable + return SDValue(); + } + + SDValue ShiftAmtVal = In.getOperand(1); + if (!isa<ConstantSDNode>(ShiftAmtVal)) + return SDValue(); + + uint64_t ShiftAmt = In.getNode()->getConstantOperandVal(1); + + // The extracted value is not extracted at the right position + if (ShiftAmt != i * ScalarTypeBitsize) + return SDValue(); + } + } + + // Only cast if the size is the same + if (Src.getValueType().getSizeInBits() != VT.getSizeInBits()) + return SDValue(); + + return DAG.getBitcast(VT, Src); +} + SDValue DAGCombiner::createBuildVecShuffle(const SDLoc &DL, SDNode *N, ArrayRef<int> VectorMask, SDValue VecIn1, SDValue VecIn2, unsigned LeftIdx, bool DidSplitVec) { - MVT IdxTy = TLI.getVectorIdxTy(DAG.getDataLayout()); - SDValue ZeroIdx = DAG.getConstant(0, DL, IdxTy); + SDValue ZeroIdx = DAG.getVectorIdxConstant(0, DL); EVT VT = N->getValueType(0); EVT InVT1 = VecIn1.getValueType(); @@ -17492,7 +18222,7 @@ SDValue DAGCombiner::createBuildVecShuffle(const SDLoc &DL, SDNode *N, // If we only have one input vector, and it's twice the size of the // output, split it in two. VecIn2 = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VT, VecIn1, - DAG.getConstant(NumElems, DL, IdxTy)); + DAG.getVectorIdxConstant(NumElems, DL)); VecIn1 = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VT, VecIn1, ZeroIdx); // Since we now have shorter input vectors, adjust the offset of the // second vector's start. @@ -17699,6 +18429,9 @@ SDValue DAGCombiner::reduceBuildVecToShuffle(SDNode *N) { return SDValue(); SDValue ExtractedFromVec = Op.getOperand(0); + if (ExtractedFromVec.getValueType().isScalableVector()) + return SDValue(); + const APInt &ExtractIdx = Op.getConstantOperandAPInt(1); if (ExtractIdx.uge(ExtractedFromVec.getValueType().getVectorNumElements())) return SDValue(); @@ -17733,7 +18466,6 @@ SDValue DAGCombiner::reduceBuildVecToShuffle(SDNode *N) { unsigned NearestPow2 = 0; SDValue Vec = VecIn.back(); EVT InVT = Vec.getValueType(); - MVT IdxTy = TLI.getVectorIdxTy(DAG.getDataLayout()); SmallVector<unsigned, 8> IndexVec(NumElems, 0); for (unsigned i = 0; i < NumElems; i++) { @@ -17752,9 +18484,9 @@ SDValue DAGCombiner::reduceBuildVecToShuffle(SDNode *N) { InVT.getVectorElementType(), SplitSize); if (TLI.isTypeLegal(SplitVT)) { SDValue VecIn2 = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, SplitVT, Vec, - DAG.getConstant(SplitSize, DL, IdxTy)); + DAG.getVectorIdxConstant(SplitSize, DL)); SDValue VecIn1 = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, SplitVT, Vec, - DAG.getConstant(0, DL, IdxTy)); + DAG.getVectorIdxConstant(0, DL)); VecIn.pop_back(); VecIn.push_back(VecIn1); VecIn.push_back(VecIn2); @@ -17986,6 +18718,9 @@ SDValue DAGCombiner::visitBUILD_VECTOR(SDNode *N) { if (SDValue V = reduceBuildVecExtToExtBuildVec(N)) return V; + if (SDValue V = reduceBuildVecTruncToBitCast(N)) + return V; + if (SDValue V = reduceBuildVecToShuffle(N)) return V; @@ -18080,6 +18815,7 @@ static SDValue combineConcatVectorOfExtracts(SDNode *N, SelectionDAG &DAG) { // What vector are we extracting the subvector from and at what index? SDValue ExtVec = Op.getOperand(0); + int ExtIdx = Op.getConstantOperandVal(1); // We want the EVT of the original extraction to correctly scale the // extraction index. 
@@ -18092,10 +18828,6 @@ static SDValue combineConcatVectorOfExtracts(SDNode *N, SelectionDAG &DAG) { continue; } - if (!isa<ConstantSDNode>(Op.getOperand(1))) - return SDValue(); - int ExtIdx = Op.getConstantOperandVal(1); - // Ensure that we are extracting a subvector from a vector the same // size as the result. if (ExtVT.getSizeInBits() != VT.getSizeInBits()) @@ -18129,6 +18861,69 @@ static SDValue combineConcatVectorOfExtracts(SDNode *N, SelectionDAG &DAG) { DAG.getBitcast(VT, SV1), Mask, DAG); } +static SDValue combineConcatVectorOfCasts(SDNode *N, SelectionDAG &DAG) { + unsigned CastOpcode = N->getOperand(0).getOpcode(); + switch (CastOpcode) { + case ISD::SINT_TO_FP: + case ISD::UINT_TO_FP: + case ISD::FP_TO_SINT: + case ISD::FP_TO_UINT: + // TODO: Allow more opcodes? + // case ISD::BITCAST: + // case ISD::TRUNCATE: + // case ISD::ZERO_EXTEND: + // case ISD::SIGN_EXTEND: + // case ISD::FP_EXTEND: + break; + default: + return SDValue(); + } + + EVT SrcVT = N->getOperand(0).getOperand(0).getValueType(); + if (!SrcVT.isVector()) + return SDValue(); + + // All operands of the concat must be the same kind of cast from the same + // source type. + SmallVector<SDValue, 4> SrcOps; + for (SDValue Op : N->ops()) { + if (Op.getOpcode() != CastOpcode || !Op.hasOneUse() || + Op.getOperand(0).getValueType() != SrcVT) + return SDValue(); + SrcOps.push_back(Op.getOperand(0)); + } + + // The wider cast must be supported by the target. This is unusual because + // the operation support type parameter depends on the opcode. In addition, + // check the other type in the cast to make sure this is really legal. + EVT VT = N->getValueType(0); + EVT SrcEltVT = SrcVT.getVectorElementType(); + unsigned NumElts = SrcVT.getVectorElementCount().Min * N->getNumOperands(); + EVT ConcatSrcVT = EVT::getVectorVT(*DAG.getContext(), SrcEltVT, NumElts); + const TargetLowering &TLI = DAG.getTargetLoweringInfo(); + switch (CastOpcode) { + case ISD::SINT_TO_FP: + case ISD::UINT_TO_FP: + if (!TLI.isOperationLegalOrCustom(CastOpcode, ConcatSrcVT) || + !TLI.isTypeLegal(VT)) + return SDValue(); + break; + case ISD::FP_TO_SINT: + case ISD::FP_TO_UINT: + if (!TLI.isOperationLegalOrCustom(CastOpcode, VT) || + !TLI.isTypeLegal(ConcatSrcVT)) + return SDValue(); + break; + default: + llvm_unreachable("Unexpected cast opcode"); + } + + // concat (cast X), (cast Y)... -> cast (concat X, Y...) + SDLoc DL(N); + SDValue NewConcat = DAG.getNode(ISD::CONCAT_VECTORS, DL, ConcatSrcVT, SrcOps); + return DAG.getNode(CastOpcode, DL, VT, NewConcat); +} + SDValue DAGCombiner::visitCONCAT_VECTORS(SDNode *N) { // If we only have one input vector, we don't need to do any concatenation. if (N->getNumOperands() == 1) @@ -18256,6 +19051,9 @@ SDValue DAGCombiner::visitCONCAT_VECTORS(SDNode *N) { if (SDValue V = combineConcatVectorOfExtracts(N, DAG)) return V; + if (SDValue V = combineConcatVectorOfCasts(N, DAG)) + return V; + // Type legalization of vectors and DAG canonicalization of SHUFFLE_VECTOR // nodes often generate nop CONCAT_VECTOR nodes. // Scan the CONCAT_VECTOR operands and look for a CONCAT operations that @@ -18287,14 +19085,9 @@ SDValue DAGCombiner::visitCONCAT_VECTORS(SDNode *N) { return SDValue(); } - auto *CS = dyn_cast<ConstantSDNode>(Op.getOperand(1)); - // The extract index must be constant. - if (!CS) - return SDValue(); - // Check that we are reading from the identity index. 
unsigned IdentityIndex = i * PartNumElem; - if (CS->getAPIntValue() != IdentityIndex) + if (Op.getConstantOperandAPInt(1) != IdentityIndex) return SDValue(); } @@ -18377,6 +19170,15 @@ static SDValue narrowExtractedVectorBinOp(SDNode *Extract, SelectionDAG &DAG) { if (!TLI.isBinOp(BOpcode) || BinOp.getNode()->getNumValues() != 1) return SDValue(); + // Exclude the fake form of fneg (fsub -0.0, x) because that is likely to be + // reduced to the unary fneg when it is visited, and we probably want to deal + // with fneg in a target-specific way. + if (BOpcode == ISD::FSUB) { + auto *C = isConstOrConstSplatFP(BinOp.getOperand(0), /*AllowUndefs*/ true); + if (C && C->getValueAPF().isNegZero()) + return SDValue(); + } + // The binop must be a vector type, so we can extract some fraction of it. EVT WideBVT = BinOp.getValueType(); if (!WideBVT.isVector()) @@ -18412,12 +19214,11 @@ static SDValue narrowExtractedVectorBinOp(SDNode *Extract, SelectionDAG &DAG) { // bitcasted. unsigned ConcatOpNum = ExtractIndex / VT.getVectorNumElements(); unsigned ExtBOIdx = ConcatOpNum * NarrowBVT.getVectorNumElements(); - EVT ExtBOIdxVT = Extract->getOperand(1).getValueType(); if (TLI.isExtractSubvectorCheap(NarrowBVT, WideBVT, ExtBOIdx) && BinOp.hasOneUse() && Extract->getOperand(0)->hasOneUse()) { // extract (binop B0, B1), N --> binop (extract B0, N), (extract B1, N) SDLoc DL(Extract); - SDValue NewExtIndex = DAG.getConstant(ExtBOIdx, DL, ExtBOIdxVT); + SDValue NewExtIndex = DAG.getVectorIdxConstant(ExtBOIdx, DL); SDValue X = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, NarrowBVT, BinOp.getOperand(0), NewExtIndex); SDValue Y = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, NarrowBVT, @@ -18457,7 +19258,7 @@ static SDValue narrowExtractedVectorBinOp(SDNode *Extract, SelectionDAG &DAG) { // extract (binop (concat X1, X2), Y), N --> binop XN, (extract Y, IndexC) // extract (binop X, (concat Y1, Y2)), N --> binop (extract X, IndexC), YN SDLoc DL(Extract); - SDValue IndexC = DAG.getConstant(ExtBOIdx, DL, ExtBOIdxVT); + SDValue IndexC = DAG.getVectorIdxConstant(ExtBOIdx, DL); SDValue X = SubVecL ? DAG.getBitcast(NarrowBVT, SubVecL) : DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, NarrowBVT, BinOp.getOperand(0), IndexC); @@ -18489,6 +19290,26 @@ static SDValue narrowExtractedVectorLoad(SDNode *Extract, SelectionDAG &DAG) { // Allow targets to opt-out. EVT VT = Extract->getValueType(0); + + // We can only create byte sized loads. + if (!VT.isByteSized()) + return SDValue(); + + unsigned Index = ExtIdx->getZExtValue(); + unsigned NumElts = VT.getVectorNumElements(); + + // If the index is a multiple of the extract element count, we can offset the + // address by the store size multiplied by the subvector index. Otherwise if + // the scalar type is byte sized, we can just use the index multiplied by + // the element size in bytes as the offset. + unsigned Offset; + if (Index % NumElts == 0) + Offset = (Index / NumElts) * VT.getStoreSize(); + else if (VT.getScalarType().isByteSized()) + Offset = Index * VT.getScalarType().getStoreSize(); + else + return SDValue(); + const TargetLowering &TLI = DAG.getTargetLoweringInfo(); if (!TLI.shouldReduceLoadWidth(Ld, Ld->getExtensionType(), VT)) return SDValue(); @@ -18496,8 +19317,7 @@ static SDValue narrowExtractedVectorLoad(SDNode *Extract, SelectionDAG &DAG) { // The narrow load will be offset from the base address of the old load if // we are extracting from something besides index 0 (little-endian). 
SDLoc DL(Extract); - SDValue BaseAddr = Ld->getOperand(1); - unsigned Offset = ExtIdx->getZExtValue() * VT.getScalarType().getStoreSize(); + SDValue BaseAddr = Ld->getBasePtr(); // TODO: Use "BaseIndexOffset" to make this more effective. SDValue NewAddr = DAG.getMemBasePlusOffset(BaseAddr, Offset, DL); @@ -18512,6 +19332,7 @@ static SDValue narrowExtractedVectorLoad(SDNode *Extract, SelectionDAG &DAG) { SDValue DAGCombiner::visitEXTRACT_SUBVECTOR(SDNode *N) { EVT NVT = N->getValueType(0); SDValue V = N->getOperand(0); + uint64_t ExtIdx = N->getConstantOperandVal(1); // Extract from UNDEF is UNDEF. if (V.isUndef()) @@ -18523,9 +19344,7 @@ SDValue DAGCombiner::visitEXTRACT_SUBVECTOR(SDNode *N) { // Combine an extract of an extract into a single extract_subvector. // ext (ext X, C), 0 --> ext X, C - SDValue Index = N->getOperand(1); - if (isNullConstant(Index) && V.getOpcode() == ISD::EXTRACT_SUBVECTOR && - V.hasOneUse() && isa<ConstantSDNode>(V.getOperand(1))) { + if (ExtIdx == 0 && V.getOpcode() == ISD::EXTRACT_SUBVECTOR && V.hasOneUse()) { if (TLI.isExtractSubvectorCheap(NVT, V.getOperand(0).getValueType(), V.getConstantOperandVal(1)) && TLI.isOperationLegalOrCustom(ISD::EXTRACT_SUBVECTOR, NVT)) { @@ -18536,21 +19355,20 @@ SDValue DAGCombiner::visitEXTRACT_SUBVECTOR(SDNode *N) { // Try to move vector bitcast after extract_subv by scaling extraction index: // extract_subv (bitcast X), Index --> bitcast (extract_subv X, Index') - if (isa<ConstantSDNode>(Index) && V.getOpcode() == ISD::BITCAST && + if (V.getOpcode() == ISD::BITCAST && V.getOperand(0).getValueType().isVector()) { SDValue SrcOp = V.getOperand(0); EVT SrcVT = SrcOp.getValueType(); - unsigned SrcNumElts = SrcVT.getVectorNumElements(); - unsigned DestNumElts = V.getValueType().getVectorNumElements(); + unsigned SrcNumElts = SrcVT.getVectorMinNumElements(); + unsigned DestNumElts = V.getValueType().getVectorMinNumElements(); if ((SrcNumElts % DestNumElts) == 0) { unsigned SrcDestRatio = SrcNumElts / DestNumElts; - unsigned NewExtNumElts = NVT.getVectorNumElements() * SrcDestRatio; + ElementCount NewExtEC = NVT.getVectorElementCount() * SrcDestRatio; EVT NewExtVT = EVT::getVectorVT(*DAG.getContext(), SrcVT.getScalarType(), - NewExtNumElts); + NewExtEC); if (TLI.isOperationLegalOrCustom(ISD::EXTRACT_SUBVECTOR, NewExtVT)) { - unsigned IndexValScaled = N->getConstantOperandVal(1) * SrcDestRatio; SDLoc DL(N); - SDValue NewIndex = DAG.getIntPtrConstant(IndexValScaled, DL); + SDValue NewIndex = DAG.getVectorIdxConstant(ExtIdx * SrcDestRatio, DL); SDValue NewExtract = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, NewExtVT, V.getOperand(0), NewIndex); return DAG.getBitcast(NVT, NewExtract); @@ -18558,34 +19376,43 @@ SDValue DAGCombiner::visitEXTRACT_SUBVECTOR(SDNode *N) { } if ((DestNumElts % SrcNumElts) == 0) { unsigned DestSrcRatio = DestNumElts / SrcNumElts; - if ((NVT.getVectorNumElements() % DestSrcRatio) == 0) { - unsigned NewExtNumElts = NVT.getVectorNumElements() / DestSrcRatio; - EVT NewExtVT = EVT::getVectorVT(*DAG.getContext(), - SrcVT.getScalarType(), NewExtNumElts); - if ((N->getConstantOperandVal(1) % DestSrcRatio) == 0 && - TLI.isOperationLegalOrCustom(ISD::EXTRACT_SUBVECTOR, NewExtVT)) { - unsigned IndexValScaled = N->getConstantOperandVal(1) / DestSrcRatio; + if ((NVT.getVectorMinNumElements() % DestSrcRatio) == 0) { + ElementCount NewExtEC = NVT.getVectorElementCount() / DestSrcRatio; + EVT ScalarVT = SrcVT.getScalarType(); + if ((ExtIdx % DestSrcRatio) == 0) { SDLoc DL(N); - SDValue NewIndex = 
DAG.getIntPtrConstant(IndexValScaled, DL); - SDValue NewExtract = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, NewExtVT, - V.getOperand(0), NewIndex); - return DAG.getBitcast(NVT, NewExtract); + unsigned IndexValScaled = ExtIdx / DestSrcRatio; + EVT NewExtVT = + EVT::getVectorVT(*DAG.getContext(), ScalarVT, NewExtEC); + if (TLI.isOperationLegalOrCustom(ISD::EXTRACT_SUBVECTOR, NewExtVT)) { + SDValue NewIndex = DAG.getVectorIdxConstant(IndexValScaled, DL); + SDValue NewExtract = + DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, NewExtVT, + V.getOperand(0), NewIndex); + return DAG.getBitcast(NVT, NewExtract); + } + if (NewExtEC == 1 && + TLI.isOperationLegalOrCustom(ISD::EXTRACT_VECTOR_ELT, ScalarVT)) { + SDValue NewIndex = DAG.getVectorIdxConstant(IndexValScaled, DL); + SDValue NewExtract = + DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, ScalarVT, + V.getOperand(0), NewIndex); + return DAG.getBitcast(NVT, NewExtract); + } } } } } - if (V.getOpcode() == ISD::CONCAT_VECTORS && isa<ConstantSDNode>(Index)) { + if (V.getOpcode() == ISD::CONCAT_VECTORS) { + unsigned ExtNumElts = NVT.getVectorMinNumElements(); EVT ConcatSrcVT = V.getOperand(0).getValueType(); assert(ConcatSrcVT.getVectorElementType() == NVT.getVectorElementType() && "Concat and extract subvector do not change element type"); - - unsigned ExtIdx = N->getConstantOperandVal(1); - unsigned ExtNumElts = NVT.getVectorNumElements(); - assert(ExtIdx % ExtNumElts == 0 && + assert((ExtIdx % ExtNumElts) == 0 && "Extract index is not a multiple of the input vector length."); - unsigned ConcatSrcNumElts = ConcatSrcVT.getVectorNumElements(); + unsigned ConcatSrcNumElts = ConcatSrcVT.getVectorMinNumElements(); unsigned ConcatOpIdx = ExtIdx / ConcatSrcNumElts; // If the concatenated source types match this extract, it's a direct @@ -18599,15 +19426,14 @@ SDValue DAGCombiner::visitEXTRACT_SUBVECTOR(SDNode *N) { // concat operand. Example: // v2i8 extract_subvec (v16i8 concat (v8i8 X), (v8i8 Y), 14 --> // v2i8 extract_subvec v8i8 Y, 6 - if (ConcatSrcNumElts % ExtNumElts == 0) { + if (NVT.isFixedLengthVector() && ConcatSrcNumElts % ExtNumElts == 0) { SDLoc DL(N); unsigned NewExtIdx = ExtIdx - ConcatOpIdx * ConcatSrcNumElts; assert(NewExtIdx + ExtNumElts <= ConcatSrcNumElts && "Trying to extract from >1 concat operand?"); assert(NewExtIdx % ExtNumElts == 0 && "Extract index is not a multiple of the input vector length."); - MVT IdxTy = TLI.getVectorIdxTy(DAG.getDataLayout()); - SDValue NewIndexC = DAG.getConstant(NewExtIdx, DL, IdxTy); + SDValue NewIndexC = DAG.getVectorIdxConstant(NewExtIdx, DL); return DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, NVT, V.getOperand(ConcatOpIdx), NewIndexC); } @@ -18617,37 +19443,33 @@ SDValue DAGCombiner::visitEXTRACT_SUBVECTOR(SDNode *N) { // If the input is a build vector. Try to make a smaller build vector. if (V.getOpcode() == ISD::BUILD_VECTOR) { - if (auto *IdxC = dyn_cast<ConstantSDNode>(Index)) { - EVT InVT = V.getValueType(); - unsigned ExtractSize = NVT.getSizeInBits(); - unsigned EltSize = InVT.getScalarSizeInBits(); - // Only do this if we won't split any elements. - if (ExtractSize % EltSize == 0) { - unsigned NumElems = ExtractSize / EltSize; - EVT EltVT = InVT.getVectorElementType(); - EVT ExtractVT = NumElems == 1 ? 
EltVT - : EVT::getVectorVT(*DAG.getContext(), - EltVT, NumElems); - if ((Level < AfterLegalizeDAG || - (NumElems == 1 || - TLI.isOperationLegal(ISD::BUILD_VECTOR, ExtractVT))) && - (!LegalTypes || TLI.isTypeLegal(ExtractVT))) { - unsigned IdxVal = IdxC->getZExtValue(); - IdxVal *= NVT.getScalarSizeInBits(); - IdxVal /= EltSize; - - if (NumElems == 1) { - SDValue Src = V->getOperand(IdxVal); - if (EltVT != Src.getValueType()) - Src = DAG.getNode(ISD::TRUNCATE, SDLoc(N), InVT, Src); - return DAG.getBitcast(NVT, Src); - } - - // Extract the pieces from the original build_vector. - SDValue BuildVec = DAG.getBuildVector( - ExtractVT, SDLoc(N), V->ops().slice(IdxVal, NumElems)); - return DAG.getBitcast(NVT, BuildVec); + EVT InVT = V.getValueType(); + unsigned ExtractSize = NVT.getSizeInBits(); + unsigned EltSize = InVT.getScalarSizeInBits(); + // Only do this if we won't split any elements. + if (ExtractSize % EltSize == 0) { + unsigned NumElems = ExtractSize / EltSize; + EVT EltVT = InVT.getVectorElementType(); + EVT ExtractVT = + NumElems == 1 ? EltVT + : EVT::getVectorVT(*DAG.getContext(), EltVT, NumElems); + if ((Level < AfterLegalizeDAG || + (NumElems == 1 || + TLI.isOperationLegal(ISD::BUILD_VECTOR, ExtractVT))) && + (!LegalTypes || TLI.isTypeLegal(ExtractVT))) { + unsigned IdxVal = (ExtIdx * NVT.getScalarSizeInBits()) / EltSize; + + if (NumElems == 1) { + SDValue Src = V->getOperand(IdxVal); + if (EltVT != Src.getValueType()) + Src = DAG.getNode(ISD::TRUNCATE, SDLoc(N), InVT, Src); + return DAG.getBitcast(NVT, Src); } + + // Extract the pieces from the original build_vector. + SDValue BuildVec = DAG.getBuildVector(ExtractVT, SDLoc(N), + V->ops().slice(IdxVal, NumElems)); + return DAG.getBitcast(NVT, BuildVec); } } } @@ -18659,23 +19481,19 @@ SDValue DAGCombiner::visitEXTRACT_SUBVECTOR(SDNode *N) { if (!NVT.bitsEq(SmallVT)) return SDValue(); - // Only handle cases where both indexes are constants. 
- auto *ExtIdx = dyn_cast<ConstantSDNode>(Index); - auto *InsIdx = dyn_cast<ConstantSDNode>(V.getOperand(2)); - if (InsIdx && ExtIdx) { - // Combine: - // (extract_subvec (insert_subvec V1, V2, InsIdx), ExtIdx) - // Into: - // indices are equal or bit offsets are equal => V1 - // otherwise => (extract_subvec V1, ExtIdx) - if (InsIdx->getZExtValue() * SmallVT.getScalarSizeInBits() == - ExtIdx->getZExtValue() * NVT.getScalarSizeInBits()) - return DAG.getBitcast(NVT, V.getOperand(1)); - return DAG.getNode( - ISD::EXTRACT_SUBVECTOR, SDLoc(N), NVT, - DAG.getBitcast(N->getOperand(0).getValueType(), V.getOperand(0)), - Index); - } + // Combine: + // (extract_subvec (insert_subvec V1, V2, InsIdx), ExtIdx) + // Into: + // indices are equal or bit offsets are equal => V1 + // otherwise => (extract_subvec V1, ExtIdx) + uint64_t InsIdx = V.getConstantOperandVal(2); + if (InsIdx * SmallVT.getScalarSizeInBits() == + ExtIdx * NVT.getScalarSizeInBits()) + return DAG.getBitcast(NVT, V.getOperand(1)); + return DAG.getNode( + ISD::EXTRACT_SUBVECTOR, SDLoc(N), NVT, + DAG.getBitcast(N->getOperand(0).getValueType(), V.getOperand(0)), + N->getOperand(1)); } if (SDValue NarrowBOp = narrowExtractedVectorBinOp(N, DAG)) @@ -19064,6 +19882,57 @@ static SDValue combineShuffleOfSplatVal(ShuffleVectorSDNode *Shuf, NewMask); } +/// Combine shuffle of shuffle of the form: +/// shuf (shuf X, undef, InnerMask), undef, OuterMask --> splat X +static SDValue formSplatFromShuffles(ShuffleVectorSDNode *OuterShuf, + SelectionDAG &DAG) { + if (!OuterShuf->getOperand(1).isUndef()) + return SDValue(); + auto *InnerShuf = dyn_cast<ShuffleVectorSDNode>(OuterShuf->getOperand(0)); + if (!InnerShuf || !InnerShuf->getOperand(1).isUndef()) + return SDValue(); + + ArrayRef<int> OuterMask = OuterShuf->getMask(); + ArrayRef<int> InnerMask = InnerShuf->getMask(); + unsigned NumElts = OuterMask.size(); + assert(NumElts == InnerMask.size() && "Mask length mismatch"); + SmallVector<int, 32> CombinedMask(NumElts, -1); + int SplatIndex = -1; + for (unsigned i = 0; i != NumElts; ++i) { + // Undef lanes remain undef. + int OuterMaskElt = OuterMask[i]; + if (OuterMaskElt == -1) + continue; + + // Peek through the shuffle masks to get the underlying source element. + int InnerMaskElt = InnerMask[OuterMaskElt]; + if (InnerMaskElt == -1) + continue; + + // Initialize the splatted element. + if (SplatIndex == -1) + SplatIndex = InnerMaskElt; + + // Non-matching index - this is not a splat. + if (SplatIndex != InnerMaskElt) + return SDValue(); + + CombinedMask[i] = InnerMaskElt; + } + assert((all_of(CombinedMask, [](int M) { return M == -1; }) || + getSplatIndex(CombinedMask) != -1) && + "Expected a splat mask"); + + // TODO: The transform may be a win even if the mask is not legal. + EVT VT = OuterShuf->getValueType(0); + assert(VT == InnerShuf->getValueType(0) && "Expected matching shuffle types"); + if (!DAG.getTargetLoweringInfo().isShuffleMaskLegal(CombinedMask, VT)) + return SDValue(); + + return DAG.getVectorShuffle(VT, SDLoc(OuterShuf), InnerShuf->getOperand(0), + InnerShuf->getOperand(1), CombinedMask); +} + /// If the shuffle mask is taking exactly one element from the first vector /// operand and passing through all other elements from the second vector /// operand, return the index of the mask element that is choosing an element @@ -19136,8 +20005,7 @@ static SDValue replaceShuffleOfInsert(ShuffleVectorSDNode *Shuf, // element used. Therefore, our new insert element occurs at the shuffle's // mask index value, not the insert's index value. 
// shuffle (insertelt v1, x, C), v2, mask --> insertelt v2, x, C' - SDValue NewInsIndex = DAG.getConstant(ShufOp0Index, SDLoc(Shuf), - Op0.getOperand(2).getValueType()); + SDValue NewInsIndex = DAG.getVectorIdxConstant(ShufOp0Index, SDLoc(Shuf)); return DAG.getNode(ISD::INSERT_VECTOR_ELT, SDLoc(Shuf), Op0.getValueType(), Op1, Op0.getOperand(1), NewInsIndex); } @@ -19223,6 +20091,9 @@ SDValue DAGCombiner::visitVECTOR_SHUFFLE(SDNode *N) { if (SDValue V = combineShuffleOfSplatVal(SVN, DAG)) return V; + if (SDValue V = formSplatFromShuffles(SVN, DAG)) + return V; + // If it is a splat, check if the argument vector is another splat or a // build_vector. if (SVN->isSplat() && SVN->getSplatIndex() < (int)NumElts) { @@ -19234,7 +20105,7 @@ SDValue DAGCombiner::visitVECTOR_SHUFFLE(SDNode *N) { SDValue L = N0.getOperand(0), R = N0.getOperand(1); SDLoc DL(N); EVT EltVT = VT.getScalarType(); - SDValue Index = DAG.getIntPtrConstant(SplatIndex, DL); + SDValue Index = DAG.getVectorIdxConstant(SplatIndex, DL); SDValue ExtL = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, L, Index); SDValue ExtR = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, R, Index); SDValue NewBO = DAG.getNode(N0.getOpcode(), DL, EltVT, ExtL, ExtR, @@ -19354,16 +20225,6 @@ SDValue DAGCombiner::visitVECTOR_SHUFFLE(SDNode *N) { if (N0.getOpcode() == ISD::BITCAST && N0.hasOneUse() && N1.isUndef() && Level < AfterLegalizeVectorOps && TLI.isTypeLegal(VT)) { - auto ScaleShuffleMask = [](ArrayRef<int> Mask, int Scale) { - if (Scale == 1) - return SmallVector<int, 8>(Mask.begin(), Mask.end()); - - SmallVector<int, 8> NewMask; - for (int M : Mask) - for (int s = 0; s != Scale; ++s) - NewMask.push_back(M < 0 ? -1 : Scale * M + s); - return NewMask; - }; SDValue BC0 = peekThroughOneUseBitcasts(N0); if (BC0.getOpcode() == ISD::VECTOR_SHUFFLE && BC0.hasOneUse()) { @@ -19383,10 +20244,10 @@ SDValue DAGCombiner::visitVECTOR_SHUFFLE(SDNode *N) { // Scale the shuffle masks to the smaller scalar type. ShuffleVectorSDNode *InnerSVN = cast<ShuffleVectorSDNode>(BC0); - SmallVector<int, 8> InnerMask = - ScaleShuffleMask(InnerSVN->getMask(), InnerScale); - SmallVector<int, 8> OuterMask = - ScaleShuffleMask(SVN->getMask(), OuterScale); + SmallVector<int, 8> InnerMask; + SmallVector<int, 8> OuterMask; + narrowShuffleMaskElts(InnerScale, InnerSVN->getMask(), InnerMask); + narrowShuffleMaskElts(OuterScale, SVN->getMask(), OuterMask); // Merge the shuffle masks. SmallVector<int, 8> NewMask; @@ -19547,7 +20408,9 @@ SDValue DAGCombiner::visitSCALAR_TO_VECTOR(SDNode *N) { // Replace a SCALAR_TO_VECTOR(EXTRACT_VECTOR_ELT(V,C0)) pattern // with a VECTOR_SHUFFLE and possible truncate. - if (InVal.getOpcode() == ISD::EXTRACT_VECTOR_ELT) { + if (InVal.getOpcode() == ISD::EXTRACT_VECTOR_ELT && + VT.isFixedLengthVector() && + InVal->getOperand(0).getValueType().isFixedLengthVector()) { SDValue InVec = InVal->getOperand(0); SDValue EltNo = InVal->getOperand(1); auto InVecT = InVec.getValueType(); @@ -19576,11 +20439,10 @@ SDValue DAGCombiner::visitSCALAR_TO_VECTOR(SDNode *N) { return LegalShuffle; // If not we must truncate the vector. 
if (VT.getVectorNumElements() != InVecT.getVectorNumElements()) { - MVT IdxTy = TLI.getVectorIdxTy(DAG.getDataLayout()); - SDValue ZeroIdx = DAG.getConstant(0, SDLoc(N), IdxTy); - EVT SubVT = - EVT::getVectorVT(*DAG.getContext(), InVecT.getVectorElementType(), - VT.getVectorNumElements()); + SDValue ZeroIdx = DAG.getVectorIdxConstant(0, SDLoc(N)); + EVT SubVT = EVT::getVectorVT(*DAG.getContext(), + InVecT.getVectorElementType(), + VT.getVectorNumElements()); return DAG.getNode(ISD::EXTRACT_SUBVECTOR, SDLoc(N), SubVT, LegalShuffle, ZeroIdx); } @@ -19597,6 +20459,7 @@ SDValue DAGCombiner::visitINSERT_SUBVECTOR(SDNode *N) { SDValue N0 = N->getOperand(0); SDValue N1 = N->getOperand(1); SDValue N2 = N->getOperand(2); + uint64_t InsIdx = N->getConstantOperandVal(2); // If inserting an UNDEF, just return the original vector. if (N1.isUndef()) @@ -19657,11 +20520,6 @@ SDValue DAGCombiner::visitINSERT_SUBVECTOR(SDNode *N) { return DAG.getNode(ISD::INSERT_SUBVECTOR, SDLoc(N), VT, N0, N1.getOperand(1), N2); - if (!isa<ConstantSDNode>(N2)) - return SDValue(); - - uint64_t InsIdx = cast<ConstantSDNode>(N2)->getZExtValue(); - // Push subvector bitcasts to the output, adjusting the index as we go. // insert_subvector(bitcast(v), bitcast(s), c1) // -> bitcast(insert_subvector(v, s, c2)) @@ -19676,19 +20534,18 @@ SDValue DAGCombiner::visitINSERT_SUBVECTOR(SDNode *N) { EVT NewVT; SDLoc DL(N); SDValue NewIdx; - MVT IdxVT = TLI.getVectorIdxTy(DAG.getDataLayout()); LLVMContext &Ctx = *DAG.getContext(); unsigned NumElts = VT.getVectorNumElements(); unsigned EltSizeInBits = VT.getScalarSizeInBits(); if ((EltSizeInBits % N1SrcSVT.getSizeInBits()) == 0) { unsigned Scale = EltSizeInBits / N1SrcSVT.getSizeInBits(); NewVT = EVT::getVectorVT(Ctx, N1SrcSVT, NumElts * Scale); - NewIdx = DAG.getConstant(InsIdx * Scale, DL, IdxVT); + NewIdx = DAG.getVectorIdxConstant(InsIdx * Scale, DL); } else if ((N1SrcSVT.getSizeInBits() % EltSizeInBits) == 0) { unsigned Scale = N1SrcSVT.getSizeInBits() / EltSizeInBits; if ((NumElts % Scale) == 0 && (InsIdx % Scale) == 0) { NewVT = EVT::getVectorVT(Ctx, N1SrcSVT, NumElts / Scale); - NewIdx = DAG.getConstant(InsIdx / Scale, DL, IdxVT); + NewIdx = DAG.getVectorIdxConstant(InsIdx / Scale, DL); } } if (NewIdx && hasOperation(ISD::INSERT_SUBVECTOR, NewVT)) { @@ -19704,8 +20561,7 @@ SDValue DAGCombiner::visitINSERT_SUBVECTOR(SDNode *N) { // (insert_subvector (insert_subvector A, Idx0), Idx1) // -> (insert_subvector (insert_subvector A, Idx1), Idx0) if (N0.getOpcode() == ISD::INSERT_SUBVECTOR && N0.hasOneUse() && - N1.getValueType() == N0.getOperand(1).getValueType() && - isa<ConstantSDNode>(N0.getOperand(2))) { + N1.getValueType() == N0.getOperand(1).getValueType()) { unsigned OtherIdx = N0.getConstantOperandVal(2); if (InsIdx < OtherIdx) { // Swap nodes. @@ -19722,10 +20578,8 @@ SDValue DAGCombiner::visitINSERT_SUBVECTOR(SDNode *N) { if (N0.getOpcode() == ISD::CONCAT_VECTORS && N0.hasOneUse() && N0.getOperand(0).getValueType() == N1.getValueType()) { unsigned Factor = N1.getValueType().getVectorNumElements(); - SmallVector<SDValue, 8> Ops(N0->op_begin(), N0->op_end()); - Ops[cast<ConstantSDNode>(N2)->getZExtValue() / Factor] = N1; - + Ops[InsIdx / Factor] = N1; return DAG.getNode(ISD::CONCAT_VECTORS, SDLoc(N), VT, Ops); } @@ -19769,9 +20623,9 @@ SDValue DAGCombiner::visitVECREDUCE(SDNode *N) { // VECREDUCE over 1-element vector is just an extract. 
if (VT.getVectorNumElements() == 1) { SDLoc dl(N); - SDValue Res = DAG.getNode( - ISD::EXTRACT_VECTOR_ELT, dl, VT.getVectorElementType(), N0, - DAG.getConstant(0, dl, TLI.getVectorIdxTy(DAG.getDataLayout()))); + SDValue Res = + DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, VT.getVectorElementType(), N0, + DAG.getVectorIdxConstant(0, dl)); if (Res.getValueType() != N->getValueType(0)) Res = DAG.getNode(ISD::ANY_EXTEND, dl, N->getValueType(0), Res); return Res; @@ -19904,10 +20758,9 @@ static SDValue scalarizeBinOpOfSplats(SDNode *N, SelectionDAG &DAG) { return SDValue(); SDLoc DL(N); - SDValue IndexC = - DAG.getConstant(Index0, DL, TLI.getVectorIdxTy(DAG.getDataLayout())); - SDValue X = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, N0, IndexC); - SDValue Y = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, N1, IndexC); + SDValue IndexC = DAG.getVectorIdxConstant(Index0, DL); + SDValue X = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, Src0, IndexC); + SDValue Y = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, Src1, IndexC); SDValue ScalarBO = DAG.getNode(Opcode, DL, EltVT, X, Y, N->getFlags()); // If all lanes but 1 are undefined, no need to splat the scalar result. @@ -19937,6 +20790,7 @@ SDValue DAGCombiner::SimplifyVBinOp(SDNode *N) { SDValue Ops[] = {LHS, RHS}; EVT VT = N->getValueType(0); unsigned Opcode = N->getOpcode(); + SDNodeFlags Flags = N->getFlags(); // See if we can constant fold the vector operation. if (SDValue Fold = DAG.FoldConstantVectorArithmetic( @@ -19960,10 +20814,37 @@ SDValue DAGCombiner::SimplifyVBinOp(SDNode *N) { (LHS.hasOneUse() || RHS.hasOneUse() || LHS == RHS)) { SDLoc DL(N); SDValue NewBinOp = DAG.getNode(Opcode, DL, VT, LHS.getOperand(0), - RHS.getOperand(0), N->getFlags()); + RHS.getOperand(0), Flags); SDValue UndefV = LHS.getOperand(1); return DAG.getVectorShuffle(VT, DL, NewBinOp, UndefV, Shuf0->getMask()); } + + // Try to sink a splat shuffle after a binop with a uniform constant. + // This is limited to cases where neither the shuffle nor the constant have + // undefined elements because that could be poison-unsafe or inhibit + // demanded elements analysis. It is further limited to not change a splat + // of an inserted scalar because that may be optimized better by + // load-folding or other target-specific behaviors. + if (isConstOrConstSplat(RHS) && Shuf0 && is_splat(Shuf0->getMask()) && + Shuf0->hasOneUse() && Shuf0->getOperand(1).isUndef() && + Shuf0->getOperand(0).getOpcode() != ISD::INSERT_VECTOR_ELT) { + // binop (splat X), (splat C) --> splat (binop X, C) + SDLoc DL(N); + SDValue X = Shuf0->getOperand(0); + SDValue NewBinOp = DAG.getNode(Opcode, DL, VT, X, RHS, Flags); + return DAG.getVectorShuffle(VT, DL, NewBinOp, DAG.getUNDEF(VT), + Shuf0->getMask()); + } + if (isConstOrConstSplat(LHS) && Shuf1 && is_splat(Shuf1->getMask()) && + Shuf1->hasOneUse() && Shuf1->getOperand(1).isUndef() && + Shuf1->getOperand(0).getOpcode() != ISD::INSERT_VECTOR_ELT) { + // binop (splat C), (splat X) --> splat (binop C, X) + SDLoc DL(N); + SDValue X = Shuf1->getOperand(0); + SDValue NewBinOp = DAG.getNode(Opcode, DL, VT, LHS, X, Flags); + return DAG.getVectorShuffle(VT, DL, NewBinOp, DAG.getUNDEF(VT), + Shuf1->getMask()); + } } // The following pattern is likely to emerge with vector reduction ops. Moving @@ -20361,8 +21242,8 @@ SDValue DAGCombiner::convertSelectOfFPConstantsToLoadOffset( // Create a ConstantArray of the two constants. 
Constant *CA = ConstantArray::get(ArrayType::get(FPTy, 2), Elts); SDValue CPIdx = DAG.getConstantPool(CA, TLI.getPointerTy(DAG.getDataLayout()), - TD.getPrefTypeAlignment(FPTy)); - unsigned Alignment = cast<ConstantPoolSDNode>(CPIdx)->getAlignment(); + TD.getPrefTypeAlign(FPTy)); + Align Alignment = cast<ConstantPoolSDNode>(CPIdx)->getAlign(); // Get offsets to the 0 and 1 elements of the array, so we can select between // them. @@ -20797,7 +21678,10 @@ SDValue DAGCombiner::buildSqrtEstimateImpl(SDValue Op, SDNodeFlags Flags, EVT CCVT = getSetCCResultType(VT); ISD::NodeType SelOpcode = VT.isVector() ? ISD::VSELECT : ISD::SELECT; DenormalMode DenormMode = DAG.getDenormalMode(VT); - if (DenormMode == DenormalMode::IEEE) { + if (DenormMode.Input == DenormalMode::IEEE) { + // This is specifically a check for the handling of denormal inputs, + // not the result. + // fabs(X) < SmallestNormal ? 0.0 : Est const fltSemantics &FltSem = DAG.EVTToAPFloatSemantics(VT); APFloat SmallestNorm = APFloat::getSmallestNormalized(FltSem); @@ -20849,9 +21733,11 @@ bool DAGCombiner::isAlias(SDNode *Op0, SDNode *Op1) const { : (LSN->getAddressingMode() == ISD::PRE_DEC) ? -1 * C->getSExtValue() : 0; + uint64_t Size = + MemoryLocation::getSizeOrUnknown(LSN->getMemoryVT().getStoreSize()); return {LSN->isVolatile(), LSN->isAtomic(), LSN->getBasePtr(), Offset /*base offset*/, - Optional<int64_t>(LSN->getMemoryVT().getStoreSize()), + Optional<int64_t>(Size), LSN->getMemOperand()}; } if (const auto *LN = cast<LifetimeSDNode>(N)) @@ -20911,21 +21797,24 @@ bool DAGCombiner::isAlias(SDNode *Op0, SDNode *Op1) const { // If we know required SrcValue1 and SrcValue2 have relatively large // alignment compared to the size and offset of the access, we may be able // to prove they do not alias. This check is conservative for now to catch - // cases created by splitting vector types. + // cases created by splitting vector types, it only works when the offsets are + // multiples of the size of the data. int64_t SrcValOffset0 = MUC0.MMO->getOffset(); int64_t SrcValOffset1 = MUC1.MMO->getOffset(); - unsigned OrigAlignment0 = MUC0.MMO->getBaseAlignment(); - unsigned OrigAlignment1 = MUC1.MMO->getBaseAlignment(); + Align OrigAlignment0 = MUC0.MMO->getBaseAlign(); + Align OrigAlignment1 = MUC1.MMO->getBaseAlign(); + auto &Size0 = MUC0.NumBytes; + auto &Size1 = MUC1.NumBytes; if (OrigAlignment0 == OrigAlignment1 && SrcValOffset0 != SrcValOffset1 && - MUC0.NumBytes.hasValue() && MUC1.NumBytes.hasValue() && - *MUC0.NumBytes == *MUC1.NumBytes && OrigAlignment0 > *MUC0.NumBytes) { - int64_t OffAlign0 = SrcValOffset0 % OrigAlignment0; - int64_t OffAlign1 = SrcValOffset1 % OrigAlignment1; + Size0.hasValue() && Size1.hasValue() && *Size0 == *Size1 && + OrigAlignment0 > *Size0 && SrcValOffset0 % *Size0 == 0 && + SrcValOffset1 % *Size1 == 0) { + int64_t OffAlign0 = SrcValOffset0 % OrigAlignment0.value(); + int64_t OffAlign1 = SrcValOffset1 % OrigAlignment1.value(); // There is no overlap between these relatively aligned accesses of // similar size. Return no alias. 
- if ((OffAlign0 + *MUC0.NumBytes) <= OffAlign1 || - (OffAlign1 + *MUC1.NumBytes) <= OffAlign0) + if ((OffAlign0 + *Size0) <= OffAlign1 || (OffAlign1 + *Size1) <= OffAlign0) return false; } @@ -20938,11 +21827,12 @@ bool DAGCombiner::isAlias(SDNode *Op0, SDNode *Op1) const { UseAA = false; #endif - if (UseAA && AA && MUC0.MMO->getValue() && MUC1.MMO->getValue()) { + if (UseAA && AA && MUC0.MMO->getValue() && MUC1.MMO->getValue() && + Size0.hasValue() && Size1.hasValue()) { // Use alias analysis information. int64_t MinOffset = std::min(SrcValOffset0, SrcValOffset1); - int64_t Overlap0 = *MUC0.NumBytes + SrcValOffset0 - MinOffset; - int64_t Overlap1 = *MUC1.NumBytes + SrcValOffset1 - MinOffset; + int64_t Overlap0 = *Size0 + SrcValOffset0 - MinOffset; + int64_t Overlap1 = *Size1 + SrcValOffset1 - MinOffset; AliasResult AAResult = AA->alias( MemoryLocation(MUC0.MMO->getValue(), Overlap0, UseTBAA ? MUC0.MMO->getAAInfo() : AAMDNodes()), @@ -21099,10 +21989,10 @@ bool operator!=(const UnitT &, const UnitT &) { return false; } // redundant, as this function gets called when visiting every store // node, so why not let the work be done on each store as it's visited? // -// I believe this is mainly important because MergeConsecutiveStores +// I believe this is mainly important because mergeConsecutiveStores // is unable to deal with merging stores of different sizes, so unless // we improve the chains of all the potential candidates up-front -// before running MergeConsecutiveStores, it might only see some of +// before running mergeConsecutiveStores, it might only see some of // the nodes that will eventually be candidates, and then not be able // to go from a partially-merged state to the desired final // fully-merged state. @@ -21131,6 +22021,12 @@ bool DAGCombiner::parallelizeChainedStores(StoreSDNode *St) { if (BasePtr.getBase().isUndef()) return false; + // BaseIndexOffset assumes that offsets are fixed-size, which + // is not valid for scalable vectors where the offsets are + // scaled by `vscale`, so bail out early. + if (St->getMemoryVT().isScalableVector()) + return false; + // Add ST's interval. Intervals.insert(0, (St->getMemoryVT().getSizeInBits() + 7) / 8, Unit); |