diff options
Diffstat (limited to 'llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp')
| -rw-r--r-- | llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp | 2633 |
1 files changed, 1882 insertions, 751 deletions
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp index 0a3ebd73d272..de909cc10795 100644 --- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp @@ -30,11 +30,14 @@ #include "llvm/Analysis/AliasAnalysis.h" #include "llvm/Analysis/MemoryLocation.h" #include "llvm/Analysis/TargetLibraryInfo.h" +#include "llvm/Analysis/ValueTracking.h" #include "llvm/Analysis/VectorUtils.h" +#include "llvm/CodeGen/ByteProvider.h" #include "llvm/CodeGen/DAGCombine.h" #include "llvm/CodeGen/ISDOpcodes.h" #include "llvm/CodeGen/MachineFunction.h" #include "llvm/CodeGen/MachineMemOperand.h" +#include "llvm/CodeGen/MachineValueType.h" #include "llvm/CodeGen/RuntimeLibcalls.h" #include "llvm/CodeGen/SelectionDAG.h" #include "llvm/CodeGen/SelectionDAGAddressAnalysis.h" @@ -57,7 +60,6 @@ #include "llvm/Support/Debug.h" #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/KnownBits.h" -#include "llvm/Support/MachineValueType.h" #include "llvm/Support/MathExtras.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Target/TargetMachine.h" @@ -169,7 +171,8 @@ namespace { /// them) when they are deleted from the underlying DAG. It relies on /// stable indices of nodes within the worklist. DenseMap<SDNode *, unsigned> WorklistMap; - /// This records all nodes attempted to add to the worklist since we + + /// This records all nodes attempted to be added to the worklist since we /// considered a new worklist entry. As we keep do not add duplicate nodes /// in the worklist, this is different from the tail of the worklist. SmallSetVector<SDNode *, 32> PruningList; @@ -262,7 +265,7 @@ namespace { /// Add to the worklist making sure its instance is at the back (next to be /// processed.) - void AddToWorklist(SDNode *N) { + void AddToWorklist(SDNode *N, bool IsCandidateForPruning = true) { assert(N->getOpcode() != ISD::DELETED_NODE && "Deleted Node added to Worklist"); @@ -271,7 +274,8 @@ namespace { if (N->getOpcode() == ISD::HANDLENODE) return; - ConsiderForPruning(N); + if (IsCandidateForPruning) + ConsiderForPruning(N); if (WorklistMap.insert(std::make_pair(N, Worklist.size())).second) Worklist.push_back(N); @@ -362,6 +366,11 @@ namespace { SDValue SplitIndexingFromLoad(LoadSDNode *LD); bool SliceUpLoad(SDNode *N); + // Looks up the chain to find a unique (unaliased) store feeding the passed + // load. If no such store is found, returns a nullptr. + // Note: This will look past a CALLSEQ_START if the load is chained to it so + // so that it can find stack stores for byval params. + StoreSDNode *getUniqueStoreFeeding(LoadSDNode *LD, int64_t &Offset); // Scalars have size 0 to distinguish from singleton vectors. SDValue ForwardStoreValueToDirectLoad(LoadSDNode *LD); bool getTruncatedStoreValue(StoreSDNode *ST, SDValue &Val); @@ -417,11 +426,12 @@ namespace { SDValue visitSUBC(SDNode *N); SDValue visitSUBO(SDNode *N); SDValue visitADDE(SDNode *N); - SDValue visitADDCARRY(SDNode *N); + SDValue visitUADDO_CARRY(SDNode *N); SDValue visitSADDO_CARRY(SDNode *N); - SDValue visitADDCARRYLike(SDValue N0, SDValue N1, SDValue CarryIn, SDNode *N); + SDValue visitUADDO_CARRYLike(SDValue N0, SDValue N1, SDValue CarryIn, + SDNode *N); SDValue visitSUBE(SDNode *N); - SDValue visitSUBCARRY(SDNode *N); + SDValue visitUSUBO_CARRY(SDNode *N); SDValue visitSSUBO_CARRY(SDNode *N); SDValue visitMUL(SDNode *N); SDValue visitMULFIX(SDNode *N); @@ -434,6 +444,7 @@ namespace { SDValue visitMULHU(SDNode *N); SDValue visitMULHS(SDNode *N); SDValue visitAVG(SDNode *N); + SDValue visitABD(SDNode *N); SDValue visitSMUL_LOHI(SDNode *N); SDValue visitUMUL_LOHI(SDNode *N); SDValue visitMULO(SDNode *N); @@ -476,10 +487,12 @@ namespace { SDValue visitFREEZE(SDNode *N); SDValue visitBUILD_PAIR(SDNode *N); SDValue visitFADD(SDNode *N); + SDValue visitVP_FADD(SDNode *N); + SDValue visitVP_FSUB(SDNode *N); SDValue visitSTRICT_FADD(SDNode *N); SDValue visitFSUB(SDNode *N); SDValue visitFMUL(SDNode *N); - SDValue visitFMA(SDNode *N); + template <class MatchContextClass> SDValue visitFMA(SDNode *N); SDValue visitFDIV(SDNode *N); SDValue visitFREM(SDNode *N); SDValue visitFSQRT(SDNode *N); @@ -495,6 +508,7 @@ namespace { SDValue visitFABS(SDNode *N); SDValue visitFCEIL(SDNode *N); SDValue visitFTRUNC(SDNode *N); + SDValue visitFFREXP(SDNode *N); SDValue visitFFLOOR(SDNode *N); SDValue visitFMinMax(SDNode *N); SDValue visitBRCOND(SDNode *N); @@ -503,6 +517,7 @@ namespace { SDValue replaceStoreChain(StoreSDNode *ST, SDValue BetterChain); SDValue replaceStoreOfFPConstant(StoreSDNode *ST); + SDValue replaceStoreOfInsertLoad(StoreSDNode *ST); bool refineExtractVectorEltIntoMultipleNarrowExtractVectorElts(SDNode *N); @@ -527,8 +542,12 @@ namespace { SDValue visitFP_TO_BF16(SDNode *N); SDValue visitVECREDUCE(SDNode *N); SDValue visitVPOp(SDNode *N); + SDValue visitGET_FPENV_MEM(SDNode *N); + SDValue visitSET_FPENV_MEM(SDNode *N); + template <class MatchContextClass> SDValue visitFADDForFMACombine(SDNode *N); + template <class MatchContextClass> SDValue visitFSUBForFMACombine(SDNode *N); SDValue visitFMULForFMADistributiveCombine(SDNode *N); @@ -539,9 +558,12 @@ namespace { SDValue N0, SDValue N1); SDValue reassociateOpsCommutative(unsigned Opc, const SDLoc &DL, SDValue N0, - SDValue N1); + SDValue N1, SDNodeFlags Flags); SDValue reassociateOps(unsigned Opc, const SDLoc &DL, SDValue N0, SDValue N1, SDNodeFlags Flags); + SDValue reassociateReduction(unsigned ResOpc, unsigned Opc, const SDLoc &DL, + EVT VT, SDValue N0, SDValue N1, + SDNodeFlags Flags = SDNodeFlags()); SDValue visitShiftByConstant(SDNode *N); @@ -579,11 +601,15 @@ namespace { SDValue SimplifyNodeWithTwoResults(SDNode *N, unsigned LoOp, unsigned HiOp); SDValue CombineConsecutiveLoads(SDNode *N, EVT VT); + SDValue foldBitcastedFPLogic(SDNode *N, SelectionDAG &DAG, + const TargetLowering &TLI); + SDValue CombineExtLoad(SDNode *N); SDValue CombineZExtLogicopShiftLoad(SDNode *N); SDValue combineRepeatedFPDivisors(SDNode *N); SDValue mergeInsertEltWithShuffle(SDNode *N, unsigned InsIndex); SDValue combineInsertEltToShuffle(SDNode *N, unsigned InsIndex); + SDValue combineInsertEltToLoad(SDNode *N, unsigned InsIndex); SDValue ConstantFoldBITCASTofBUILD_VECTOR(SDNode *, EVT); SDValue BuildSDIV(SDNode *N); SDValue BuildSDIVPow2(SDNode *N); @@ -713,6 +739,11 @@ namespace { SDValue getMergeStoreChains(SmallVectorImpl<MemOpLink> &StoreNodes, unsigned NumStores); + /// Helper function for mergeConsecutiveStores which checks if all the store + /// nodes have the same underlying object. We can still reuse the first + /// store's pointer info if all the stores are from the same object. + bool hasSameUnderlyingObj(ArrayRef<MemOpLink> StoreNodes); + /// This is a helper function for mergeConsecutiveStores. When the source /// elements of the consecutive stores are all constants or all extracted /// vector elements, try to merge them into one larger store introducing @@ -841,6 +872,138 @@ public: void NodeInserted(SDNode *N) override { DC.ConsiderForPruning(N); } }; +class EmptyMatchContext { + SelectionDAG &DAG; + const TargetLowering &TLI; + +public: + EmptyMatchContext(SelectionDAG &DAG, const TargetLowering &TLI, SDNode *Root) + : DAG(DAG), TLI(TLI) {} + + bool match(SDValue OpN, unsigned Opcode) const { + return Opcode == OpN->getOpcode(); + } + + // Same as SelectionDAG::getNode(). + template <typename... ArgT> SDValue getNode(ArgT &&...Args) { + return DAG.getNode(std::forward<ArgT>(Args)...); + } + + bool isOperationLegalOrCustom(unsigned Op, EVT VT, + bool LegalOnly = false) const { + return TLI.isOperationLegalOrCustom(Op, VT, LegalOnly); + } +}; + +class VPMatchContext { + SelectionDAG &DAG; + const TargetLowering &TLI; + SDValue RootMaskOp; + SDValue RootVectorLenOp; + +public: + VPMatchContext(SelectionDAG &DAG, const TargetLowering &TLI, SDNode *Root) + : DAG(DAG), TLI(TLI), RootMaskOp(), RootVectorLenOp() { + assert(Root->isVPOpcode()); + if (auto RootMaskPos = ISD::getVPMaskIdx(Root->getOpcode())) + RootMaskOp = Root->getOperand(*RootMaskPos); + + if (auto RootVLenPos = + ISD::getVPExplicitVectorLengthIdx(Root->getOpcode())) + RootVectorLenOp = Root->getOperand(*RootVLenPos); + } + + /// whether \p OpVal is a node that is functionally compatible with the + /// NodeType \p Opc + bool match(SDValue OpVal, unsigned Opc) const { + if (!OpVal->isVPOpcode()) + return OpVal->getOpcode() == Opc; + + auto BaseOpc = ISD::getBaseOpcodeForVP(OpVal->getOpcode(), + !OpVal->getFlags().hasNoFPExcept()); + if (BaseOpc != Opc) + return false; + + // Make sure the mask of OpVal is true mask or is same as Root's. + unsigned VPOpcode = OpVal->getOpcode(); + if (auto MaskPos = ISD::getVPMaskIdx(VPOpcode)) { + SDValue MaskOp = OpVal.getOperand(*MaskPos); + if (RootMaskOp != MaskOp && + !ISD::isConstantSplatVectorAllOnes(MaskOp.getNode())) + return false; + } + + // Make sure the EVL of OpVal is same as Root's. + if (auto VLenPos = ISD::getVPExplicitVectorLengthIdx(VPOpcode)) + if (RootVectorLenOp != OpVal.getOperand(*VLenPos)) + return false; + return true; + } + + // Specialize based on number of operands. + // TODO emit VP intrinsics where MaskOp/VectorLenOp != null + // SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT) { return + // DAG.getNode(Opcode, DL, VT); } + SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue Operand) { + unsigned VPOpcode = ISD::getVPForBaseOpcode(Opcode); + assert(ISD::getVPMaskIdx(VPOpcode) == 1 && + ISD::getVPExplicitVectorLengthIdx(VPOpcode) == 2); + return DAG.getNode(VPOpcode, DL, VT, + {Operand, RootMaskOp, RootVectorLenOp}); + } + + SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue N1, + SDValue N2) { + unsigned VPOpcode = ISD::getVPForBaseOpcode(Opcode); + assert(ISD::getVPMaskIdx(VPOpcode) == 2 && + ISD::getVPExplicitVectorLengthIdx(VPOpcode) == 3); + return DAG.getNode(VPOpcode, DL, VT, + {N1, N2, RootMaskOp, RootVectorLenOp}); + } + + SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue N1, + SDValue N2, SDValue N3) { + unsigned VPOpcode = ISD::getVPForBaseOpcode(Opcode); + assert(ISD::getVPMaskIdx(VPOpcode) == 3 && + ISD::getVPExplicitVectorLengthIdx(VPOpcode) == 4); + return DAG.getNode(VPOpcode, DL, VT, + {N1, N2, N3, RootMaskOp, RootVectorLenOp}); + } + + SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue Operand, + SDNodeFlags Flags) { + unsigned VPOpcode = ISD::getVPForBaseOpcode(Opcode); + assert(ISD::getVPMaskIdx(VPOpcode) == 1 && + ISD::getVPExplicitVectorLengthIdx(VPOpcode) == 2); + return DAG.getNode(VPOpcode, DL, VT, {Operand, RootMaskOp, RootVectorLenOp}, + Flags); + } + + SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue N1, + SDValue N2, SDNodeFlags Flags) { + unsigned VPOpcode = ISD::getVPForBaseOpcode(Opcode); + assert(ISD::getVPMaskIdx(VPOpcode) == 2 && + ISD::getVPExplicitVectorLengthIdx(VPOpcode) == 3); + return DAG.getNode(VPOpcode, DL, VT, {N1, N2, RootMaskOp, RootVectorLenOp}, + Flags); + } + + SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue N1, + SDValue N2, SDValue N3, SDNodeFlags Flags) { + unsigned VPOpcode = ISD::getVPForBaseOpcode(Opcode); + assert(ISD::getVPMaskIdx(VPOpcode) == 3 && + ISD::getVPExplicitVectorLengthIdx(VPOpcode) == 4); + return DAG.getNode(VPOpcode, DL, VT, + {N1, N2, N3, RootMaskOp, RootVectorLenOp}, Flags); + } + + bool isOperationLegalOrCustom(unsigned Op, EVT VT, + bool LegalOnly = false) const { + unsigned VPOp = ISD::getVPForBaseOpcode(Op); + return TLI.isOperationLegalOrCustom(VPOp, VT, LegalOnly); + } +}; + } // end anonymous namespace //===----------------------------------------------------------------------===// @@ -1099,7 +1262,8 @@ bool DAGCombiner::reassociationCanBreakAddressingModePattern(unsigned Opc, // Helper for DAGCombiner::reassociateOps. Try to reassociate an expression // such as (Opc N0, N1), if \p N0 is the same kind of operation as \p Opc. SDValue DAGCombiner::reassociateOpsCommutative(unsigned Opc, const SDLoc &DL, - SDValue N0, SDValue N1) { + SDValue N0, SDValue N1, + SDNodeFlags Flags) { EVT VT = N0.getValueType(); if (N0.getOpcode() != Opc) @@ -1118,8 +1282,12 @@ SDValue DAGCombiner::reassociateOpsCommutative(unsigned Opc, const SDLoc &DL, if (TLI.isReassocProfitable(DAG, N0, N1)) { // Reassociate: (op (op x, c1), y) -> (op (op x, y), c1) // iff (op x, c1) has one use - SDValue OpNode = DAG.getNode(Opc, SDLoc(N0), VT, N00, N1); - return DAG.getNode(Opc, DL, VT, OpNode, N01); + SDNodeFlags NewFlags; + if (N0.getOpcode() == ISD::ADD && N0->getFlags().hasNoUnsignedWrap() && + Flags.hasNoUnsignedWrap()) + NewFlags.setNoUnsignedWrap(true); + SDValue OpNode = DAG.getNode(Opc, SDLoc(N0), VT, N00, N1, NewFlags); + return DAG.getNode(Opc, DL, VT, OpNode, N01, NewFlags); } } @@ -1177,13 +1345,32 @@ SDValue DAGCombiner::reassociateOps(unsigned Opc, const SDLoc &DL, SDValue N0, if (!Flags.hasAllowReassociation() || !Flags.hasNoSignedZeros()) return SDValue(); - if (SDValue Combined = reassociateOpsCommutative(Opc, DL, N0, N1)) + if (SDValue Combined = reassociateOpsCommutative(Opc, DL, N0, N1, Flags)) return Combined; - if (SDValue Combined = reassociateOpsCommutative(Opc, DL, N1, N0)) + if (SDValue Combined = reassociateOpsCommutative(Opc, DL, N1, N0, Flags)) return Combined; return SDValue(); } +// Try to fold Opc(vecreduce(x), vecreduce(y)) -> vecreduce(Opc(x, y)) +// Note that we only expect Flags to be passed from FP operations. For integer +// operations they need to be dropped. +SDValue DAGCombiner::reassociateReduction(unsigned RedOpc, unsigned Opc, + const SDLoc &DL, EVT VT, SDValue N0, + SDValue N1, SDNodeFlags Flags) { + if (N0.getOpcode() == RedOpc && N1.getOpcode() == RedOpc && + N0.getOperand(0).getValueType() == N1.getOperand(0).getValueType() && + N0->hasOneUse() && N1->hasOneUse() && + TLI.isOperationLegalOrCustom(Opc, N0.getOperand(0).getValueType()) && + TLI.shouldReassociateReduction(RedOpc, N0.getOperand(0).getValueType())) { + SelectionDAG::FlagInserter FlagsInserter(DAG, Flags); + return DAG.getNode(RedOpc, DL, VT, + DAG.getNode(Opc, DL, N0.getOperand(0).getValueType(), + N0.getOperand(0), N1.getOperand(0))); + } + return SDValue(); +} + SDValue DAGCombiner::CombineTo(SDNode *N, const SDValue *To, unsigned NumTo, bool AddTo) { assert(N->getNumValues() == NumTo && "Broken CombineTo call!"); @@ -1591,8 +1778,13 @@ void DAGCombiner::Run(CombineLevel AtLevel) { WorklistInserter AddNodes(*this); // Add all the dag nodes to the worklist. + // + // Note: All nodes are not added to PruningList here, this is because the only + // nodes which can be deleted are those which have no uses and all other nodes + // which would otherwise be added to the worklist by the first call to + // getNextWorklistEntry are already present in it. for (SDNode &Node : DAG.allnodes()) - AddToWorklist(&Node); + AddToWorklist(&Node, /* IsCandidateForPruning */ Node.use_empty()); // Create a dummy node (which is not added to allnodes), that adds a reference // to the root node, preventing it from being deleted, and tracking any @@ -1627,11 +1819,11 @@ void DAGCombiner::Run(CombineLevel AtLevel) { // Add any operands of the new node which have not yet been combined to the // worklist as well. Because the worklist uniques things already, this // won't repeatedly process the same operand. - CombinedNodes.insert(N); for (const SDValue &ChildN : N->op_values()) if (!CombinedNodes.count(ChildN.getNode())) AddToWorklist(ChildN.getNode()); + CombinedNodes.insert(N); SDValue RV = combine(N); if (!RV.getNode()) @@ -1665,10 +1857,8 @@ void DAGCombiner::Run(CombineLevel AtLevel) { // out), because re-visiting the EntryToken and its users will not uncover // any additional opportunities, but there may be a large number of such // users, potentially causing compile time explosion. - if (RV.getOpcode() != ISD::EntryToken) { - AddToWorklist(RV.getNode()); - AddUsersToWorklist(RV.getNode()); - } + if (RV.getOpcode() != ISD::EntryToken) + AddToWorklistWithUsers(RV.getNode()); // Finally, if the node is now dead, remove it from the graph. The node // may not be dead if the replacement process recursively simplified to @@ -1700,10 +1890,10 @@ SDValue DAGCombiner::visit(SDNode *N) { case ISD::SSUBO: case ISD::USUBO: return visitSUBO(N); case ISD::ADDE: return visitADDE(N); - case ISD::ADDCARRY: return visitADDCARRY(N); + case ISD::UADDO_CARRY: return visitUADDO_CARRY(N); case ISD::SADDO_CARRY: return visitSADDO_CARRY(N); case ISD::SUBE: return visitSUBE(N); - case ISD::SUBCARRY: return visitSUBCARRY(N); + case ISD::USUBO_CARRY: return visitUSUBO_CARRY(N); case ISD::SSUBO_CARRY: return visitSSUBO_CARRY(N); case ISD::SMULFIX: case ISD::SMULFIXSAT: @@ -1720,6 +1910,8 @@ SDValue DAGCombiner::visit(SDNode *N) { case ISD::AVGFLOORU: case ISD::AVGCEILS: case ISD::AVGCEILU: return visitAVG(N); + case ISD::ABDS: + case ISD::ABDU: return visitABD(N); case ISD::SMUL_LOHI: return visitSMUL_LOHI(N); case ISD::UMUL_LOHI: return visitUMUL_LOHI(N); case ISD::SMULO: @@ -1770,7 +1962,7 @@ SDValue DAGCombiner::visit(SDNode *N) { case ISD::STRICT_FADD: return visitSTRICT_FADD(N); case ISD::FSUB: return visitFSUB(N); case ISD::FMUL: return visitFMUL(N); - case ISD::FMA: return visitFMA(N); + case ISD::FMA: return visitFMA<EmptyMatchContext>(N); case ISD::FDIV: return visitFDIV(N); case ISD::FREM: return visitFREM(N); case ISD::FSQRT: return visitFSQRT(N); @@ -1791,6 +1983,7 @@ SDValue DAGCombiner::visit(SDNode *N) { case ISD::FMAXIMUM: return visitFMinMax(N); case ISD::FCEIL: return visitFCEIL(N); case ISD::FTRUNC: return visitFTRUNC(N); + case ISD::FFREXP: return visitFFREXP(N); case ISD::BRCOND: return visitBRCOND(N); case ISD::BR_CC: return visitBR_CC(N); case ISD::LOAD: return visitLOAD(N); @@ -1812,6 +2005,8 @@ SDValue DAGCombiner::visit(SDNode *N) { case ISD::FP16_TO_FP: return visitFP16_TO_FP(N); case ISD::FP_TO_BF16: return visitFP_TO_BF16(N); case ISD::FREEZE: return visitFREEZE(N); + case ISD::GET_FPENV_MEM: return visitGET_FPENV_MEM(N); + case ISD::SET_FPENV_MEM: return visitSET_FPENV_MEM(N); case ISD::VECREDUCE_FADD: case ISD::VECREDUCE_FMUL: case ISD::VECREDUCE_ADD: @@ -1824,7 +2019,9 @@ SDValue DAGCombiner::visit(SDNode *N) { case ISD::VECREDUCE_UMAX: case ISD::VECREDUCE_UMIN: case ISD::VECREDUCE_FMAX: - case ISD::VECREDUCE_FMIN: return visitVECREDUCE(N); + case ISD::VECREDUCE_FMIN: + case ISD::VECREDUCE_FMAXIMUM: + case ISD::VECREDUCE_FMINIMUM: return visitVECREDUCE(N); #define BEGIN_REGISTER_VP_SDNODE(SDOPC, ...) case ISD::SDOPC: #include "llvm/IR/VPIntrinsics.def" return visitVPOp(N); @@ -2131,6 +2328,39 @@ static ConstantSDNode *getAsNonOpaqueConstant(SDValue N) { return Const != nullptr && !Const->isOpaque() ? Const : nullptr; } +// isTruncateOf - If N is a truncate of some other value, return true, record +// the value being truncated in Op and which of Op's bits are zero/one in Known. +// This function computes KnownBits to avoid a duplicated call to +// computeKnownBits in the caller. +static bool isTruncateOf(SelectionDAG &DAG, SDValue N, SDValue &Op, + KnownBits &Known) { + if (N->getOpcode() == ISD::TRUNCATE) { + Op = N->getOperand(0); + Known = DAG.computeKnownBits(Op); + return true; + } + + if (N.getOpcode() != ISD::SETCC || + N.getValueType().getScalarType() != MVT::i1 || + cast<CondCodeSDNode>(N.getOperand(2))->get() != ISD::SETNE) + return false; + + SDValue Op0 = N->getOperand(0); + SDValue Op1 = N->getOperand(1); + assert(Op0.getValueType() == Op1.getValueType()); + + if (isNullOrNullSplat(Op0)) + Op = Op1; + else if (isNullOrNullSplat(Op1)) + Op = Op0; + else + return false; + + Known = DAG.computeKnownBits(Op); + + return (Known.Zero | 1).isAllOnes(); +} + /// Return true if 'Use' is a load or a store that uses N as its base pointer /// and that N may be folded in the load / store addressing mode. static bool canFoldInAddressingMode(SDNode *N, SDNode *Use, SelectionDAG &DAG, @@ -2206,11 +2436,12 @@ static SDValue foldSelectWithIdentityConstant(SDNode *N, SelectionDAG &DAG, if (N1.getOpcode() != ISD::VSELECT || !N1.hasOneUse()) return SDValue(); - // We can't hoist div/rem because of immediate UB (not speculatable). - unsigned Opcode = N->getOpcode(); - if (!DAG.isSafeToSpeculativelyExecute(Opcode)) + // We can't hoist all instructions because of immediate UB (not speculatable). + // For example div/rem by zero. + if (!DAG.isSafeToSpeculativelyExecuteNode(N)) return SDValue(); + unsigned Opcode = N->getOpcode(); EVT VT = N->getValueType(0); SDValue Cond = N1.getOperand(0); SDValue TVal = N1.getOperand(1); @@ -2258,6 +2489,17 @@ SDValue DAGCombiner::foldBinOpIntoSelect(SDNode *BO) { if (Sel.getOpcode() != ISD::SELECT || !Sel.hasOneUse()) { SelOpNo = 1; Sel = BO->getOperand(1); + + // Peek through trunc to shift amount type. + if ((BinOpcode == ISD::SHL || BinOpcode == ISD::SRA || + BinOpcode == ISD::SRL) && Sel.hasOneUse()) { + // This is valid when the truncated bits of x are already zero. + SDValue Op; + KnownBits Known; + if (isTruncateOf(DAG, Sel, Op, Known) && + Known.countMaxActiveBits() < Sel.getScalarValueSizeInBits()) + Sel = Op; + } } if (Sel.getOpcode() != ISD::SELECT || !Sel.hasOneUse()) @@ -2310,18 +2552,14 @@ SDValue DAGCombiner::foldBinOpIntoSelect(SDNode *BO) { // constant. Eliminate the binop by pulling the constant math into the // select. Example: add (select Cond, CT, CF), CBO --> select Cond, CT + // CBO, CF + CBO - NewCT = SelOpNo ? DAG.getNode(BinOpcode, DL, VT, CBO, CT) - : DAG.getNode(BinOpcode, DL, VT, CT, CBO); - if (!CanFoldNonConst && !NewCT.isUndef() && - !isConstantOrConstantVector(NewCT, true) && - !DAG.isConstantFPBuildVectorOrConstantFP(NewCT)) + NewCT = SelOpNo ? DAG.FoldConstantArithmetic(BinOpcode, DL, VT, {CBO, CT}) + : DAG.FoldConstantArithmetic(BinOpcode, DL, VT, {CT, CBO}); + if (!NewCT) return SDValue(); - NewCF = SelOpNo ? DAG.getNode(BinOpcode, DL, VT, CBO, CF) - : DAG.getNode(BinOpcode, DL, VT, CF, CBO); - if (!CanFoldNonConst && !NewCF.isUndef() && - !isConstantOrConstantVector(NewCF, true) && - !DAG.isConstantFPBuildVectorOrConstantFP(NewCF)) + NewCF = SelOpNo ? DAG.FoldConstantArithmetic(BinOpcode, DL, VT, {CBO, CF}) + : DAG.FoldConstantArithmetic(BinOpcode, DL, VT, {CF, CBO}); + if (!NewCF) return SDValue(); } @@ -2420,6 +2658,12 @@ static bool isADDLike(SDValue V, const SelectionDAG &DAG) { return false; } +static bool +areBitwiseNotOfEachother(SDValue Op0, SDValue Op1) { + return (isBitwiseNot(Op0) && Op0.getOperand(0) == Op1) || + (isBitwiseNot(Op1) && Op1.getOperand(0) == Op0); +} + /// Try to fold a node that behaves like an ADD (note that N isn't necessarily /// an ISD::ADD here, it could for example be an ISD::OR if we know that there /// are no common bits set in the operands). @@ -2444,6 +2688,10 @@ SDValue DAGCombiner::visitADDLike(SDNode *N) { !DAG.isConstantIntBuildVectorOrConstantInt(N1)) return DAG.getNode(ISD::ADD, DL, VT, N1, N0); + if (areBitwiseNotOfEachother(N0, N1)) + return DAG.getConstant(APInt::getAllOnes(VT.getScalarSizeInBits()), + SDLoc(N), VT); + // fold vector ops if (VT.isVector()) { if (SDValue FoldedVOp = SimplifyVBinOp(N, DL)) @@ -2509,12 +2757,22 @@ SDValue DAGCombiner::visitADDLike(SDNode *N) { // equivalent to (add x, c). // Reassociate (add (xor x, c), y) -> (add add(x, y), c)) if (xor x, c) is // equivalent to (add x, c). + // Do this optimization only when adding c does not introduce instructions + // for adding carries. auto ReassociateAddOr = [&](SDValue N0, SDValue N1) { if (isADDLike(N0, DAG) && N0.hasOneUse() && isConstantOrConstantVector(N0.getOperand(1), /* NoOpaque */ true)) { - return DAG.getNode(ISD::ADD, DL, VT, - DAG.getNode(ISD::ADD, DL, VT, N1, N0.getOperand(0)), - N0.getOperand(1)); + // If N0's type does not split or is a sign mask, it does not introduce + // add carry. + auto TyActn = TLI.getTypeAction(*DAG.getContext(), N0.getValueType()); + bool NoAddCarry = TyActn == TargetLoweringBase::TypeLegal || + TyActn == TargetLoweringBase::TypePromoteInteger || + isMinSignedConstant(N0.getOperand(1)); + if (NoAddCarry) + return DAG.getNode( + ISD::ADD, DL, VT, + DAG.getNode(ISD::ADD, DL, VT, N1, N0.getOperand(0)), + N0.getOperand(1)); } return SDValue(); }; @@ -2522,6 +2780,11 @@ SDValue DAGCombiner::visitADDLike(SDNode *N) { return Add; if (SDValue Add = ReassociateAddOr(N1, N0)) return Add; + + // Fold add(vecreduce(x), vecreduce(y)) -> vecreduce(add(x, y)) + if (SDValue SD = + reassociateReduction(ISD::VECREDUCE_ADD, ISD::ADD, DL, VT, N0, N1)) + return SD; } // fold ((0-A) + B) -> B-A if (N0.getOpcode() == ISD::SUB && isNullOrNullSplat(N0.getOperand(0))) @@ -2626,7 +2889,10 @@ SDValue DAGCombiner::visitADDLike(SDNode *N) { // And if the target does not like this form then turn into: // sub y, (xor x, -1) if (!TLI.preferIncOfAddToSubOfNot(VT) && N0.getOpcode() == ISD::ADD && - N0.hasOneUse()) { + N0.hasOneUse() && + // Limit this to after legalization if the add has wrap flags + (Level >= AfterLegalizeDAG || (!N->getFlags().hasNoUnsignedWrap() && + !N->getFlags().hasNoSignedWrap()))) { SDValue Not = DAG.getNode(ISD::XOR, DL, VT, N0.getOperand(0), DAG.getAllOnesConstant(DL, VT)); return DAG.getNode(ISD::SUB, DL, VT, N0.getOperand(1), Not); @@ -2714,6 +2980,7 @@ SDValue DAGCombiner::visitADDSAT(SDNode *N) { SDValue N0 = N->getOperand(0); SDValue N1 = N->getOperand(1); EVT VT = N0.getValueType(); + bool IsSigned = Opcode == ISD::SADDSAT; SDLoc DL(N); // fold (add_sat x, undef) -> -1 @@ -2744,14 +3011,14 @@ SDValue DAGCombiner::visitADDSAT(SDNode *N) { return N0; // If it cannot overflow, transform into an add. - if (Opcode == ISD::UADDSAT) - if (DAG.computeOverflowKind(N0, N1) == SelectionDAG::OFK_Never) - return DAG.getNode(ISD::ADD, DL, VT, N0, N1); + if (DAG.computeOverflowForAdd(IsSigned, N0, N1) == SelectionDAG::OFK_Never) + return DAG.getNode(ISD::ADD, DL, VT, N0, N1); return SDValue(); } -static SDValue getAsCarry(const TargetLowering &TLI, SDValue V) { +static SDValue getAsCarry(const TargetLowering &TLI, SDValue V, + bool ForceCarryReconstruction = false) { bool Masked = false; // First, peel away TRUNCATE/ZERO_EXTEND/AND nodes due to legalization. @@ -2762,11 +3029,17 @@ static SDValue getAsCarry(const TargetLowering &TLI, SDValue V) { } if (V.getOpcode() == ISD::AND && isOneConstant(V.getOperand(1))) { + if (ForceCarryReconstruction) + return V; + Masked = true; V = V.getOperand(0); continue; } + if (ForceCarryReconstruction && V.getValueType() == MVT::i1) + return V; + break; } @@ -2774,7 +3047,7 @@ static SDValue getAsCarry(const TargetLowering &TLI, SDValue V) { if (V.getResNo() != 1) return SDValue(); - if (V.getOpcode() != ISD::ADDCARRY && V.getOpcode() != ISD::SUBCARRY && + if (V.getOpcode() != ISD::UADDO_CARRY && V.getOpcode() != ISD::USUBO_CARRY && V.getOpcode() != ISD::UADDO && V.getOpcode() != ISD::USUBO) return SDValue(); @@ -2842,7 +3115,10 @@ SDValue DAGCombiner::visitADDLikeCommutative(SDValue N0, SDValue N1, // And if the target does not like this form then turn into: // sub y, (xor x, -1) if (!TLI.preferIncOfAddToSubOfNot(VT) && N0.getOpcode() == ISD::ADD && - N0.hasOneUse() && isOneOrOneSplat(N0.getOperand(1))) { + N0.hasOneUse() && isOneOrOneSplat(N0.getOperand(1)) && + // Limit this to after legalization if the add has wrap flags + (Level >= AfterLegalizeDAG || (!N0->getFlags().hasNoUnsignedWrap() && + !N0->getFlags().hasNoSignedWrap()))) { SDValue Not = DAG.getNode(ISD::XOR, DL, VT, N0.getOperand(0), DAG.getAllOnesConstant(DL, VT)); return DAG.getNode(ISD::SUB, DL, VT, N1, Not); @@ -2864,6 +3140,15 @@ SDValue DAGCombiner::visitADDLikeCommutative(SDValue N0, SDValue N1, } } + // add (mul x, C), x -> mul x, C+1 + if (N0.getOpcode() == ISD::MUL && N0.getOperand(0) == N1 && + isConstantOrConstantVector(N0.getOperand(1), /*NoOpaques=*/true) && + N0.hasOneUse()) { + SDValue NewC = DAG.getNode(ISD::ADD, DL, VT, N0.getOperand(1), + DAG.getConstant(1, DL, VT)); + return DAG.getNode(ISD::MUL, DL, VT, N0.getOperand(0), NewC); + } + // If the target's bool is represented as 0/1, prefer to make this 'sub 0/1' // rather than 'add 0/-1' (the zext should get folded). // add (sext i1 Y), X --> sub X, (zext i1 Y) @@ -2884,16 +3169,16 @@ SDValue DAGCombiner::visitADDLikeCommutative(SDValue N0, SDValue N1, } } - // (add X, (addcarry Y, 0, Carry)) -> (addcarry X, Y, Carry) - if (N1.getOpcode() == ISD::ADDCARRY && isNullConstant(N1.getOperand(1)) && + // (add X, (uaddo_carry Y, 0, Carry)) -> (uaddo_carry X, Y, Carry) + if (N1.getOpcode() == ISD::UADDO_CARRY && isNullConstant(N1.getOperand(1)) && N1.getResNo() == 0) - return DAG.getNode(ISD::ADDCARRY, DL, N1->getVTList(), + return DAG.getNode(ISD::UADDO_CARRY, DL, N1->getVTList(), N0, N1.getOperand(0), N1.getOperand(2)); - // (add X, Carry) -> (addcarry X, 0, Carry) - if (TLI.isOperationLegalOrCustom(ISD::ADDCARRY, VT)) + // (add X, Carry) -> (uaddo_carry X, 0, Carry) + if (TLI.isOperationLegalOrCustom(ISD::UADDO_CARRY, VT)) if (SDValue Carry = getAsCarry(TLI, N1)) - return DAG.getNode(ISD::ADDCARRY, DL, + return DAG.getNode(ISD::UADDO_CARRY, DL, DAG.getVTList(VT, Carry.getValueType()), N0, DAG.getConstant(0, DL, VT), Carry); @@ -2923,7 +3208,7 @@ SDValue DAGCombiner::visitADDC(SDNode *N) { DL, MVT::Glue)); // If it cannot overflow, transform into an add. - if (DAG.computeOverflowKind(N0, N1) == SelectionDAG::OFK_Never) + if (DAG.computeOverflowForUnsignedAdd(N0, N1) == SelectionDAG::OFK_Never) return CombineTo(N, DAG.getNode(ISD::ADD, DL, VT, N0, N1), DAG.getNode(ISD::CARRY_FALSE, DL, MVT::Glue)); @@ -2995,12 +3280,12 @@ SDValue DAGCombiner::visitADDO(SDNode *N) { if (isNullOrNullSplat(N1)) return CombineTo(N, N0, DAG.getConstant(0, DL, CarryVT)); - if (!IsSigned) { - // If it cannot overflow, transform into an add. - if (DAG.computeOverflowKind(N0, N1) == SelectionDAG::OFK_Never) - return CombineTo(N, DAG.getNode(ISD::ADD, DL, VT, N0, N1), - DAG.getConstant(0, DL, CarryVT)); + // If it cannot overflow, transform into an add. + if (DAG.computeOverflowForAdd(IsSigned, N0, N1) == SelectionDAG::OFK_Never) + return CombineTo(N, DAG.getNode(ISD::ADD, DL, VT, N0, N1), + DAG.getConstant(0, DL, CarryVT)); + if (!IsSigned) { // fold (uaddo (xor a, -1), 1) -> (usub 0, a) and flip carry. if (isBitwiseNot(N0) && isOneOrOneSplat(N1)) { SDValue Sub = DAG.getNode(ISD::USUBO, DL, N->getVTList(), @@ -3024,20 +3309,20 @@ SDValue DAGCombiner::visitUADDOLike(SDValue N0, SDValue N1, SDNode *N) { if (VT.isVector()) return SDValue(); - // (uaddo X, (addcarry Y, 0, Carry)) -> (addcarry X, Y, Carry) + // (uaddo X, (uaddo_carry Y, 0, Carry)) -> (uaddo_carry X, Y, Carry) // If Y + 1 cannot overflow. - if (N1.getOpcode() == ISD::ADDCARRY && isNullConstant(N1.getOperand(1))) { + if (N1.getOpcode() == ISD::UADDO_CARRY && isNullConstant(N1.getOperand(1))) { SDValue Y = N1.getOperand(0); SDValue One = DAG.getConstant(1, SDLoc(N), Y.getValueType()); - if (DAG.computeOverflowKind(Y, One) == SelectionDAG::OFK_Never) - return DAG.getNode(ISD::ADDCARRY, SDLoc(N), N->getVTList(), N0, Y, + if (DAG.computeOverflowForUnsignedAdd(Y, One) == SelectionDAG::OFK_Never) + return DAG.getNode(ISD::UADDO_CARRY, SDLoc(N), N->getVTList(), N0, Y, N1.getOperand(2)); } - // (uaddo X, Carry) -> (addcarry X, 0, Carry) - if (TLI.isOperationLegalOrCustom(ISD::ADDCARRY, VT)) + // (uaddo X, Carry) -> (uaddo_carry X, 0, Carry) + if (TLI.isOperationLegalOrCustom(ISD::UADDO_CARRY, VT)) if (SDValue Carry = getAsCarry(TLI, N1)) - return DAG.getNode(ISD::ADDCARRY, SDLoc(N), N->getVTList(), N0, + return DAG.getNode(ISD::UADDO_CARRY, SDLoc(N), N->getVTList(), N0, DAG.getConstant(0, SDLoc(N), VT), Carry); return SDValue(); @@ -3062,7 +3347,7 @@ SDValue DAGCombiner::visitADDE(SDNode *N) { return SDValue(); } -SDValue DAGCombiner::visitADDCARRY(SDNode *N) { +SDValue DAGCombiner::visitUADDO_CARRY(SDNode *N) { SDValue N0 = N->getOperand(0); SDValue N1 = N->getOperand(1); SDValue CarryIn = N->getOperand(2); @@ -3072,16 +3357,16 @@ SDValue DAGCombiner::visitADDCARRY(SDNode *N) { ConstantSDNode *N0C = dyn_cast<ConstantSDNode>(N0); ConstantSDNode *N1C = dyn_cast<ConstantSDNode>(N1); if (N0C && !N1C) - return DAG.getNode(ISD::ADDCARRY, DL, N->getVTList(), N1, N0, CarryIn); + return DAG.getNode(ISD::UADDO_CARRY, DL, N->getVTList(), N1, N0, CarryIn); - // fold (addcarry x, y, false) -> (uaddo x, y) + // fold (uaddo_carry x, y, false) -> (uaddo x, y) if (isNullConstant(CarryIn)) { if (!LegalOperations || TLI.isOperationLegalOrCustom(ISD::UADDO, N->getValueType(0))) return DAG.getNode(ISD::UADDO, DL, N->getVTList(), N0, N1); } - // fold (addcarry 0, 0, X) -> (and (ext/trunc X), 1) and no carry. + // fold (uaddo_carry 0, 0, X) -> (and (ext/trunc X), 1) and no carry. if (isNullConstant(N0) && isNullConstant(N1)) { EVT VT = N0.getValueType(); EVT CarryVT = CarryIn.getValueType(); @@ -3092,73 +3377,52 @@ SDValue DAGCombiner::visitADDCARRY(SDNode *N) { DAG.getConstant(0, DL, CarryVT)); } - if (SDValue Combined = visitADDCARRYLike(N0, N1, CarryIn, N)) + if (SDValue Combined = visitUADDO_CARRYLike(N0, N1, CarryIn, N)) return Combined; - if (SDValue Combined = visitADDCARRYLike(N1, N0, CarryIn, N)) + if (SDValue Combined = visitUADDO_CARRYLike(N1, N0, CarryIn, N)) return Combined; // We want to avoid useless duplication. - // TODO: This is done automatically for binary operations. As ADDCARRY is + // TODO: This is done automatically for binary operations. As UADDO_CARRY is // not a binary operation, this is not really possible to leverage this // existing mechanism for it. However, if more operations require the same // deduplication logic, then it may be worth generalize. SDValue Ops[] = {N1, N0, CarryIn}; SDNode *CSENode = - DAG.getNodeIfExists(ISD::ADDCARRY, N->getVTList(), Ops, N->getFlags()); + DAG.getNodeIfExists(ISD::UADDO_CARRY, N->getVTList(), Ops, N->getFlags()); if (CSENode) return SDValue(CSENode, 0); return SDValue(); } -SDValue DAGCombiner::visitSADDO_CARRY(SDNode *N) { - SDValue N0 = N->getOperand(0); - SDValue N1 = N->getOperand(1); - SDValue CarryIn = N->getOperand(2); - SDLoc DL(N); - - // canonicalize constant to RHS - ConstantSDNode *N0C = dyn_cast<ConstantSDNode>(N0); - ConstantSDNode *N1C = dyn_cast<ConstantSDNode>(N1); - if (N0C && !N1C) - return DAG.getNode(ISD::SADDO_CARRY, DL, N->getVTList(), N1, N0, CarryIn); - - // fold (saddo_carry x, y, false) -> (saddo x, y) - if (isNullConstant(CarryIn)) { - if (!LegalOperations || - TLI.isOperationLegalOrCustom(ISD::SADDO, N->getValueType(0))) - return DAG.getNode(ISD::SADDO, DL, N->getVTList(), N0, N1); - } - - return SDValue(); -} - /** * If we are facing some sort of diamond carry propapagtion pattern try to * break it up to generate something like: - * (addcarry X, 0, (addcarry A, B, Z):Carry) + * (uaddo_carry X, 0, (uaddo_carry A, B, Z):Carry) * * The end result is usually an increase in operation required, but because the * carry is now linearized, other transforms can kick in and optimize the DAG. * * Patterns typically look something like - * (uaddo A, B) - * / \ - * Carry Sum - * | \ - * | (addcarry *, 0, Z) - * | / - * \ Carry - * | / - * (addcarry X, *, *) + * (uaddo A, B) + * / \ + * Carry Sum + * | \ + * | (uaddo_carry *, 0, Z) + * | / + * \ Carry + * | / + * (uaddo_carry X, *, *) * * But numerous variation exist. Our goal is to identify A, B, X and Z and * produce a combine with a single path for carry propagation. */ -static SDValue combineADDCARRYDiamond(DAGCombiner &Combiner, SelectionDAG &DAG, - SDValue X, SDValue Carry0, SDValue Carry1, - SDNode *N) { +static SDValue combineUADDO_CARRYDiamond(DAGCombiner &Combiner, + SelectionDAG &DAG, SDValue X, + SDValue Carry0, SDValue Carry1, + SDNode *N) { if (Carry1.getResNo() != 1 || Carry0.getResNo() != 1) return SDValue(); if (Carry1.getOpcode() != ISD::UADDO) @@ -3168,9 +3432,9 @@ static SDValue combineADDCARRYDiamond(DAGCombiner &Combiner, SelectionDAG &DAG, /** * First look for a suitable Z. It will present itself in the form of - * (addcarry Y, 0, Z) or its equivalent (uaddo Y, 1) for Z=true + * (uaddo_carry Y, 0, Z) or its equivalent (uaddo Y, 1) for Z=true */ - if (Carry0.getOpcode() == ISD::ADDCARRY && + if (Carry0.getOpcode() == ISD::UADDO_CARRY && isNullConstant(Carry0.getOperand(1))) { Z = Carry0.getOperand(2); } else if (Carry0.getOpcode() == ISD::UADDO && @@ -3185,26 +3449,27 @@ static SDValue combineADDCARRYDiamond(DAGCombiner &Combiner, SelectionDAG &DAG, auto cancelDiamond = [&](SDValue A,SDValue B) { SDLoc DL(N); - SDValue NewY = DAG.getNode(ISD::ADDCARRY, DL, Carry0->getVTList(), A, B, Z); + SDValue NewY = + DAG.getNode(ISD::UADDO_CARRY, DL, Carry0->getVTList(), A, B, Z); Combiner.AddToWorklist(NewY.getNode()); - return DAG.getNode(ISD::ADDCARRY, DL, N->getVTList(), X, + return DAG.getNode(ISD::UADDO_CARRY, DL, N->getVTList(), X, DAG.getConstant(0, DL, X.getValueType()), NewY.getValue(1)); }; /** - * (uaddo A, B) - * | - * Sum - * | - * (addcarry *, 0, Z) + * (uaddo A, B) + * | + * Sum + * | + * (uaddo_carry *, 0, Z) */ if (Carry0.getOperand(0) == Carry1.getValue(0)) { return cancelDiamond(Carry1.getOperand(0), Carry1.getOperand(1)); } /** - * (addcarry A, 0, Z) + * (uaddo_carry A, 0, Z) * | * Sum * | @@ -3241,12 +3506,12 @@ static SDValue combineADDCARRYDiamond(DAGCombiner &Combiner, SelectionDAG &DAG, // | / // CarryOut = (or *, *) // -// And generate ADDCARRY (or SUBCARRY) with two result values: +// And generate UADDO_CARRY (or USUBO_CARRY) with two result values: // -// {AddCarrySum, CarryOut} = (addcarry A, B, CarryIn) +// {AddCarrySum, CarryOut} = (uaddo_carry A, B, CarryIn) // -// Our goal is to identify A, B, and CarryIn and produce ADDCARRY/SUBCARRY with -// a single path for carry/borrow out propagation: +// Our goal is to identify A, B, and CarryIn and produce UADDO_CARRY/USUBO_CARRY +// with a single path for carry/borrow out propagation. static SDValue combineCarryDiamond(SelectionDAG &DAG, const TargetLowering &TLI, SDValue N0, SDValue N1, SDNode *N) { SDValue Carry0 = getAsCarry(TLI, N0); @@ -3279,16 +3544,13 @@ static SDValue combineCarryDiamond(SelectionDAG &DAG, const TargetLowering &TLI, return SDValue(); SDValue CarryIn = Carry1.getOperand(CarryInOperandNum); - unsigned NewOp = Opcode == ISD::UADDO ? ISD::ADDCARRY : ISD::SUBCARRY; + unsigned NewOp = Opcode == ISD::UADDO ? ISD::UADDO_CARRY : ISD::USUBO_CARRY; if (!TLI.isOperationLegalOrCustom(NewOp, Carry0.getValue(0).getValueType())) return SDValue(); // Verify that the carry/borrow in is plausibly a carry/borrow bit. - // TODO: make getAsCarry() aware of how partial carries are merged. - if (CarryIn.getOpcode() != ISD::ZERO_EXTEND) - return SDValue(); - CarryIn = CarryIn.getOperand(0); - if (CarryIn.getValueType() != MVT::i1) + CarryIn = getAsCarry(TLI, CarryIn, true); + if (!CarryIn) return SDValue(); SDLoc DL(N); @@ -3315,45 +3577,68 @@ static SDValue combineCarryDiamond(SelectionDAG &DAG, const TargetLowering &TLI, return Merged.getValue(1); } -SDValue DAGCombiner::visitADDCARRYLike(SDValue N0, SDValue N1, SDValue CarryIn, - SDNode *N) { - // fold (addcarry (xor a, -1), b, c) -> (subcarry b, a, !c) and flip carry. +SDValue DAGCombiner::visitUADDO_CARRYLike(SDValue N0, SDValue N1, + SDValue CarryIn, SDNode *N) { + // fold (uaddo_carry (xor a, -1), b, c) -> (usubo_carry b, a, !c) and flip + // carry. if (isBitwiseNot(N0)) if (SDValue NotC = extractBooleanFlip(CarryIn, DAG, TLI, true)) { SDLoc DL(N); - SDValue Sub = DAG.getNode(ISD::SUBCARRY, DL, N->getVTList(), N1, + SDValue Sub = DAG.getNode(ISD::USUBO_CARRY, DL, N->getVTList(), N1, N0.getOperand(0), NotC); return CombineTo( N, Sub, DAG.getLogicalNOT(DL, Sub.getValue(1), Sub->getValueType(1))); } // Iff the flag result is dead: - // (addcarry (add|uaddo X, Y), 0, Carry) -> (addcarry X, Y, Carry) + // (uaddo_carry (add|uaddo X, Y), 0, Carry) -> (uaddo_carry X, Y, Carry) // Don't do this if the Carry comes from the uaddo. It won't remove the uaddo // or the dependency between the instructions. if ((N0.getOpcode() == ISD::ADD || (N0.getOpcode() == ISD::UADDO && N0.getResNo() == 0 && N0.getValue(1) != CarryIn)) && isNullConstant(N1) && !N->hasAnyUseOfValue(1)) - return DAG.getNode(ISD::ADDCARRY, SDLoc(N), N->getVTList(), + return DAG.getNode(ISD::UADDO_CARRY, SDLoc(N), N->getVTList(), N0.getOperand(0), N0.getOperand(1), CarryIn); /** - * When one of the addcarry argument is itself a carry, we may be facing + * When one of the uaddo_carry argument is itself a carry, we may be facing * a diamond carry propagation. In which case we try to transform the DAG * to ensure linear carry propagation if that is possible. */ if (auto Y = getAsCarry(TLI, N1)) { // Because both are carries, Y and Z can be swapped. - if (auto R = combineADDCARRYDiamond(*this, DAG, N0, Y, CarryIn, N)) + if (auto R = combineUADDO_CARRYDiamond(*this, DAG, N0, Y, CarryIn, N)) return R; - if (auto R = combineADDCARRYDiamond(*this, DAG, N0, CarryIn, Y, N)) + if (auto R = combineUADDO_CARRYDiamond(*this, DAG, N0, CarryIn, Y, N)) return R; } return SDValue(); } +SDValue DAGCombiner::visitSADDO_CARRY(SDNode *N) { + SDValue N0 = N->getOperand(0); + SDValue N1 = N->getOperand(1); + SDValue CarryIn = N->getOperand(2); + SDLoc DL(N); + + // canonicalize constant to RHS + ConstantSDNode *N0C = dyn_cast<ConstantSDNode>(N0); + ConstantSDNode *N1C = dyn_cast<ConstantSDNode>(N1); + if (N0C && !N1C) + return DAG.getNode(ISD::SADDO_CARRY, DL, N->getVTList(), N1, N0, CarryIn); + + // fold (saddo_carry x, y, false) -> (saddo x, y) + if (isNullConstant(CarryIn)) { + if (!LegalOperations || + TLI.isOperationLegalOrCustom(ISD::SADDO, N->getValueType(0))) + return DAG.getNode(ISD::SADDO, DL, N->getVTList(), N0, N1); + } + + return SDValue(); +} + // Attempt to create a USUBSAT(LHS, RHS) node with DstVT, performing a // clamp/truncation if necessary. static SDValue getTruncatedUSUBSAT(EVT DstVT, EVT SrcVT, SDValue LHS, @@ -3720,11 +4005,6 @@ SDValue DAGCombiner::visitSUB(SDNode *N) { // If the relocation model supports it, consider symbol offsets. if (GlobalAddressSDNode *GA = dyn_cast<GlobalAddressSDNode>(N0)) if (!LegalOperations && TLI.isOffsetFoldingLegal(GA)) { - // fold (sub Sym, c) -> Sym-c - if (N1C && GA->getOpcode() == ISD::GlobalAddress) - return DAG.getGlobalAddress(GA->getGlobal(), SDLoc(N1C), VT, - GA->getOffset() - - (uint64_t)N1C->getSExtValue()); // fold (sub Sym+c1, Sym+c2) -> c1-c2 if (GlobalAddressSDNode *GB = dyn_cast<GlobalAddressSDNode>(N1)) if (GA->getGlobal() == GB->getGlobal()) @@ -3776,19 +4056,19 @@ SDValue DAGCombiner::visitSUB(SDNode *N) { return DAG.getNode(ISD::ADD, DL, VT, N1, N0); } - // (sub (subcarry X, 0, Carry), Y) -> (subcarry X, Y, Carry) - if (N0.getOpcode() == ISD::SUBCARRY && isNullConstant(N0.getOperand(1)) && + // (sub (usubo_carry X, 0, Carry), Y) -> (usubo_carry X, Y, Carry) + if (N0.getOpcode() == ISD::USUBO_CARRY && isNullConstant(N0.getOperand(1)) && N0.getResNo() == 0 && N0.hasOneUse()) - return DAG.getNode(ISD::SUBCARRY, DL, N0->getVTList(), + return DAG.getNode(ISD::USUBO_CARRY, DL, N0->getVTList(), N0.getOperand(0), N1, N0.getOperand(2)); - if (TLI.isOperationLegalOrCustom(ISD::ADDCARRY, VT)) { - // (sub Carry, X) -> (addcarry (sub 0, X), 0, Carry) + if (TLI.isOperationLegalOrCustom(ISD::UADDO_CARRY, VT)) { + // (sub Carry, X) -> (uaddo_carry (sub 0, X), 0, Carry) if (SDValue Carry = getAsCarry(TLI, N0)) { SDValue X = N1; SDValue Zero = DAG.getConstant(0, DL, VT); SDValue NegX = DAG.getNode(ISD::SUB, DL, VT, Zero, X); - return DAG.getNode(ISD::ADDCARRY, DL, + return DAG.getNode(ISD::UADDO_CARRY, DL, DAG.getVTList(VT, Carry.getValueType()), NegX, Zero, Carry); } @@ -3814,7 +4094,7 @@ SDValue DAGCombiner::visitSUB(SDNode *N) { (N0.getOperand(0) != N1.getOperand(1) || N0.getOperand(1) != N1.getOperand(0))) return SDValue(); - if (!TLI.isOperationLegalOrCustom(Abd, VT)) + if (!hasOperation(Abd, VT)) return SDValue(); return DAG.getNode(Abd, DL, VT, N0.getOperand(0), N0.getOperand(1)); }; @@ -3827,9 +4107,11 @@ SDValue DAGCombiner::visitSUB(SDNode *N) { } SDValue DAGCombiner::visitSUBSAT(SDNode *N) { + unsigned Opcode = N->getOpcode(); SDValue N0 = N->getOperand(0); SDValue N1 = N->getOperand(1); EVT VT = N0.getValueType(); + bool IsSigned = Opcode == ISD::SSUBSAT; SDLoc DL(N); // fold (sub_sat x, undef) -> 0 @@ -3841,7 +4123,7 @@ SDValue DAGCombiner::visitSUBSAT(SDNode *N) { return DAG.getConstant(0, DL, VT); // fold (sub_sat c1, c2) -> c3 - if (SDValue C = DAG.FoldConstantArithmetic(N->getOpcode(), DL, VT, {N0, N1})) + if (SDValue C = DAG.FoldConstantArithmetic(Opcode, DL, VT, {N0, N1})) return C; // fold vector ops @@ -3858,6 +4140,10 @@ SDValue DAGCombiner::visitSUBSAT(SDNode *N) { if (isNullConstant(N1)) return N0; + // If it cannot overflow, transform into an sub. + if (DAG.computeOverflowForSub(IsSigned, N0, N1) == SelectionDAG::OFK_Never) + return DAG.getNode(ISD::SUB, DL, VT, N0, N1); + return SDValue(); } @@ -3911,7 +4197,7 @@ SDValue DAGCombiner::visitSUBO(SDNode *N) { ConstantSDNode *N1C = getAsNonOpaqueConstant(N1); // fold (subox, c) -> (addo x, -c) - if (IsSigned && N1C && !N1C->getAPIntValue().isMinSignedValue()) { + if (IsSigned && N1C && !N1C->isMinSignedValue()) { return DAG.getNode(ISD::SADDO, DL, N->getVTList(), N0, DAG.getConstant(-N1C->getAPIntValue(), DL, VT)); } @@ -3920,6 +4206,11 @@ SDValue DAGCombiner::visitSUBO(SDNode *N) { if (isNullOrNullSplat(N1)) return CombineTo(N, N0, DAG.getConstant(0, DL, CarryVT)); + // If it cannot overflow, transform into an sub. + if (DAG.computeOverflowForSub(IsSigned, N0, N1) == SelectionDAG::OFK_Never) + return CombineTo(N, DAG.getNode(ISD::SUB, DL, VT, N0, N1), + DAG.getConstant(0, DL, CarryVT)); + // Canonicalize (usubo -1, x) -> ~x, i.e. (xor x, -1) + no borrow if (!IsSigned && isAllOnesOrAllOnesSplat(N0)) return CombineTo(N, DAG.getNode(ISD::XOR, DL, VT, N1, N0), @@ -3940,12 +4231,12 @@ SDValue DAGCombiner::visitSUBE(SDNode *N) { return SDValue(); } -SDValue DAGCombiner::visitSUBCARRY(SDNode *N) { +SDValue DAGCombiner::visitUSUBO_CARRY(SDNode *N) { SDValue N0 = N->getOperand(0); SDValue N1 = N->getOperand(1); SDValue CarryIn = N->getOperand(2); - // fold (subcarry x, y, false) -> (usubo x, y) + // fold (usubo_carry x, y, false) -> (usubo x, y) if (isNullConstant(CarryIn)) { if (!LegalOperations || TLI.isOperationLegalOrCustom(ISD::USUBO, N->getValueType(0))) @@ -4062,13 +4353,14 @@ SDValue DAGCombiner::visitMUL(SDNode *N) { // fold (mul x, -(1 << c)) -> -(x << c) or (-x) << c if (N1IsConst && !N1IsOpaqueConst && ConstValue1.isNegatedPowerOf2()) { unsigned Log2Val = (-ConstValue1).logBase2(); + EVT ShiftVT = getShiftAmountTy(N0.getValueType()); + // FIXME: If the input is something that is easily negated (e.g. a // single-use add), we should put the negate there. return DAG.getNode(ISD::SUB, DL, VT, DAG.getConstant(0, DL, VT), DAG.getNode(ISD::SHL, DL, VT, N0, - DAG.getConstant(Log2Val, DL, - getShiftAmountTy(N0.getValueType())))); + DAG.getConstant(Log2Val, DL, ShiftVT))); } // Attempt to reuse an existing umul_lohi/smul_lohi node, but only if the @@ -4108,7 +4400,7 @@ SDValue DAGCombiner::visitMUL(SDNode *N) { unsigned MathOp = ISD::DELETED_NODE; APInt MulC = ConstValue1.abs(); // The constant `2` should be treated as (2^0 + 1). - unsigned TZeros = MulC == 2 ? 0 : MulC.countTrailingZeros(); + unsigned TZeros = MulC == 2 ? 0 : MulC.countr_zero(); MulC.lshrInPlace(TZeros); if ((MulC - 1).isPowerOf2()) MathOp = ISD::ADD; @@ -4163,8 +4455,8 @@ SDValue DAGCombiner::visitMUL(SDNode *N) { } // fold (mul (add x, c1), c2) -> (add (mul x, c2), c1*c2) - if (DAG.isConstantIntBuildVectorOrConstantInt(N1) && - N0.getOpcode() == ISD::ADD && + if (N0.getOpcode() == ISD::ADD && + DAG.isConstantIntBuildVectorOrConstantInt(N1) && DAG.isConstantIntBuildVectorOrConstantInt(N0.getOperand(1)) && isMulAddWithConstProfitable(N, N0, N1)) return DAG.getNode( @@ -4223,6 +4515,11 @@ SDValue DAGCombiner::visitMUL(SDNode *N) { if (SDValue RMUL = reassociateOps(ISD::MUL, DL, N0, N1, N->getFlags())) return RMUL; + // Fold mul(vecreduce(x), vecreduce(y)) -> vecreduce(mul(x, y)) + if (SDValue SD = + reassociateReduction(ISD::VECREDUCE_MUL, ISD::MUL, DL, VT, N0, N1)) + return SD; + // Simplify the operands using demanded-bits information. if (SimplifyDemandedBits(SDValue(N, 0))) return SDValue(N, 0); @@ -4386,7 +4683,7 @@ SDValue DAGCombiner::visitSDIV(SDNode *N) { return DAG.getNegative(N0, DL, VT); // fold (sdiv X, MIN_SIGNED) -> select(X == MIN_SIGNED, 1, 0) - if (N1C && N1C->getAPIntValue().isMinSignedValue()) + if (N1C && N1C->isMinSignedValue()) return DAG.getSelect(DL, VT, DAG.getSetCC(DL, CCVT, N0, N1, ISD::SETEQ), DAG.getConstant(1, DL, VT), DAG.getConstant(0, DL, VT)); @@ -4886,11 +5183,57 @@ SDValue DAGCombiner::visitAVG(SDNode *N) { if (N1.isUndef()) return N0; + // Fold (avg x, x) --> x + if (N0 == N1 && Level >= AfterLegalizeTypes) + return N0; + // TODO If we use avg for scalars anywhere, we can add (avgfl x, 0) -> x >> 1 return SDValue(); } +SDValue DAGCombiner::visitABD(SDNode *N) { + unsigned Opcode = N->getOpcode(); + SDValue N0 = N->getOperand(0); + SDValue N1 = N->getOperand(1); + EVT VT = N->getValueType(0); + SDLoc DL(N); + + // fold (abd c1, c2) + if (SDValue C = DAG.FoldConstantArithmetic(Opcode, DL, VT, {N0, N1})) + return C; + + // canonicalize constant to RHS. + if (DAG.isConstantIntBuildVectorOrConstantInt(N0) && + !DAG.isConstantIntBuildVectorOrConstantInt(N1)) + return DAG.getNode(Opcode, DL, N->getVTList(), N1, N0); + + if (VT.isVector()) { + if (SDValue FoldedVOp = SimplifyVBinOp(N, DL)) + return FoldedVOp; + + // fold (abds x, 0) -> abs x + // fold (abdu x, 0) -> x + if (ISD::isConstantSplatVectorAllZeros(N1.getNode())) { + if (Opcode == ISD::ABDS) + return DAG.getNode(ISD::ABS, DL, VT, N0); + if (Opcode == ISD::ABDU) + return N0; + } + } + + // fold (abd x, undef) -> 0 + if (N0.isUndef() || N1.isUndef()) + return DAG.getConstant(0, DL, VT); + + // fold (abds x, y) -> (abdu x, y) iff both args are known positive + if (Opcode == ISD::ABDS && hasOperation(ISD::ABDU, VT) && + DAG.SignBitIsZero(N0) && DAG.SignBitIsZero(N1)) + return DAG.getNode(ISD::ABDU, DL, VT, N1, N0); + + return SDValue(); +} + /// Perform optimizations common to nodes that compute two values. LoOp and HiOp /// give the opcodes for the two computations that are being performed. Return /// true if a simplification was made. @@ -5108,7 +5451,7 @@ SDValue DAGCombiner::visitMULO(SDNode *N) { // same as SimplifySelectCC. N0<N1 ? N2 : N3. static SDValue isSaturatingMinMax(SDValue N0, SDValue N1, SDValue N2, SDValue N3, ISD::CondCode CC, unsigned &BW, - bool &Unsigned) { + bool &Unsigned, SelectionDAG &DAG) { auto isSignedMinMax = [&](SDValue N0, SDValue N1, SDValue N2, SDValue N3, ISD::CondCode CC) { // The compare and select operand should be the same or the select operands @@ -5132,6 +5475,26 @@ static SDValue isSaturatingMinMax(SDValue N0, SDValue N1, SDValue N2, if (!Opcode0) return SDValue(); + // We could only need one range check, if the fptosi could never produce + // the upper value. + if (N0.getOpcode() == ISD::FP_TO_SINT && Opcode0 == ISD::SMAX) { + if (isNullOrNullSplat(N3)) { + EVT IntVT = N0.getValueType().getScalarType(); + EVT FPVT = N0.getOperand(0).getValueType().getScalarType(); + if (FPVT.isSimple()) { + Type *InputTy = FPVT.getTypeForEVT(*DAG.getContext()); + const fltSemantics &Semantics = InputTy->getFltSemantics(); + uint32_t MinBitWidth = + APFloatBase::semanticsIntSizeInBits(Semantics, /*isSigned*/ true); + if (IntVT.getSizeInBits() >= MinBitWidth) { + Unsigned = true; + BW = PowerOf2Ceil(MinBitWidth); + return N0; + } + } + } + } + SDValue N00, N01, N02, N03; ISD::CondCode N0CC; switch (N0.getOpcode()) { @@ -5194,7 +5557,7 @@ static SDValue PerformMinMaxFpToSatCombine(SDValue N0, SDValue N1, SDValue N2, SelectionDAG &DAG) { unsigned BW; bool Unsigned; - SDValue Fp = isSaturatingMinMax(N0, N1, N2, N3, CC, BW, Unsigned); + SDValue Fp = isSaturatingMinMax(N0, N1, N2, N3, CC, BW, Unsigned, DAG); if (!Fp || Fp.getOpcode() != ISD::FP_TO_SINT) return SDValue(); EVT FPVT = Fp.getOperand(0).getValueType(); @@ -5208,8 +5571,7 @@ static SDValue PerformMinMaxFpToSatCombine(SDValue N0, SDValue N1, SDValue N2, SDLoc DL(Fp); SDValue Sat = DAG.getNode(NewOpc, DL, NewVT, Fp.getOperand(0), DAG.getValueType(NewVT.getScalarType())); - return Unsigned ? DAG.getZExtOrTrunc(Sat, DL, N2->getValueType(0)) - : DAG.getSExtOrTrunc(Sat, DL, N2->getValueType(0)); + return DAG.getExtOrTrunc(!Unsigned, Sat, DL, N2->getValueType(0)); } static SDValue PerformUMinFpToSatCombine(SDValue N0, SDValue N1, SDValue N2, @@ -5298,6 +5660,25 @@ SDValue DAGCombiner::visitIMINMAX(SDNode *N) { if (SDValue S = PerformUMinFpToSatCombine(N0, N1, N0, N1, ISD::SETULT, DAG)) return S; + // Fold min/max(vecreduce(x), vecreduce(y)) -> vecreduce(min/max(x, y)) + auto ReductionOpcode = [](unsigned Opcode) { + switch (Opcode) { + case ISD::SMIN: + return ISD::VECREDUCE_SMIN; + case ISD::SMAX: + return ISD::VECREDUCE_SMAX; + case ISD::UMIN: + return ISD::VECREDUCE_UMIN; + case ISD::UMAX: + return ISD::VECREDUCE_UMAX; + default: + llvm_unreachable("Unexpected opcode"); + } + }; + if (SDValue SD = reassociateReduction(ReductionOpcode(Opcode), Opcode, + SDLoc(N), VT, N0, N1)) + return SD; + // Simplify the operands using demanded-bits information. if (SimplifyDemandedBits(SDValue(N, 0))) return SDValue(N, 0); @@ -5312,8 +5693,7 @@ SDValue DAGCombiner::hoistLogicOpWithSameOpcodeHands(SDNode *N) { EVT VT = N0.getValueType(); unsigned LogicOpcode = N->getOpcode(); unsigned HandOpcode = N0.getOpcode(); - assert((LogicOpcode == ISD::AND || LogicOpcode == ISD::OR || - LogicOpcode == ISD::XOR) && "Expected logic opcode"); + assert(ISD::isBitwiseLogicOp(LogicOpcode) && "Expected logic opcode"); assert(HandOpcode == N1.getOpcode() && "Bad input!"); // Bail early if none of these transforms apply. @@ -5323,13 +5703,14 @@ SDValue DAGCombiner::hoistLogicOpWithSameOpcodeHands(SDNode *N) { // FIXME: We should check number of uses of the operands to not increase // the instruction count for all transforms. - // Handle size-changing casts. + // Handle size-changing casts (or sign_extend_inreg). SDValue X = N0.getOperand(0); SDValue Y = N1.getOperand(0); EVT XVT = X.getValueType(); SDLoc DL(N); - if (HandOpcode == ISD::ANY_EXTEND || HandOpcode == ISD::ZERO_EXTEND || - HandOpcode == ISD::SIGN_EXTEND) { + if (ISD::isExtOpcode(HandOpcode) || ISD::isExtVecInRegOpcode(HandOpcode) || + (HandOpcode == ISD::SIGN_EXTEND_INREG && + N0.getOperand(1) == N1.getOperand(1))) { // If both operands have other uses, this transform would create extra // instructions without eliminating anything. if (!N0.hasOneUse() && !N1.hasOneUse()) @@ -5344,11 +5725,14 @@ SDValue DAGCombiner::hoistLogicOpWithSameOpcodeHands(SDNode *N) { return SDValue(); // Avoid infinite looping with PromoteIntBinOp. // TODO: Should we apply desirable/legal constraints to all opcodes? - if (HandOpcode == ISD::ANY_EXTEND && LegalTypes && - !TLI.isTypeDesirableForOp(LogicOpcode, XVT)) + if ((HandOpcode == ISD::ANY_EXTEND || + HandOpcode == ISD::ANY_EXTEND_VECTOR_INREG) && + LegalTypes && !TLI.isTypeDesirableForOp(LogicOpcode, XVT)) return SDValue(); // logic_op (hand_op X), (hand_op Y) --> hand_op (logic_op X, Y) SDValue Logic = DAG.getNode(LogicOpcode, DL, XVT, X, Y); + if (HandOpcode == ISD::SIGN_EXTEND_INREG) + return DAG.getNode(HandOpcode, DL, VT, Logic, N0.getOperand(1)); return DAG.getNode(HandOpcode, DL, VT, Logic); } @@ -5629,6 +6013,172 @@ SDValue DAGCombiner::foldLogicOfSetCCs(bool IsAnd, SDValue N0, SDValue N1, return SDValue(); } +static SDValue foldAndOrOfSETCC(SDNode *LogicOp, SelectionDAG &DAG) { + using AndOrSETCCFoldKind = TargetLowering::AndOrSETCCFoldKind; + assert( + (LogicOp->getOpcode() == ISD::AND || LogicOp->getOpcode() == ISD::OR) && + "Invalid Op to combine SETCC with"); + + // TODO: Search past casts/truncates. + SDValue LHS = LogicOp->getOperand(0); + SDValue RHS = LogicOp->getOperand(1); + if (LHS->getOpcode() != ISD::SETCC || RHS->getOpcode() != ISD::SETCC) + return SDValue(); + + const TargetLowering &TLI = DAG.getTargetLoweringInfo(); + AndOrSETCCFoldKind TargetPreference = TLI.isDesirableToCombineLogicOpOfSETCC( + LogicOp, LHS.getNode(), RHS.getNode()); + + SDValue LHS0 = LHS->getOperand(0); + SDValue RHS0 = RHS->getOperand(0); + SDValue LHS1 = LHS->getOperand(1); + SDValue RHS1 = RHS->getOperand(1); + // TODO: We don't actually need a splat here, for vectors we just need the + // invariants to hold for each element. + auto *LHS1C = isConstOrConstSplat(LHS1); + auto *RHS1C = isConstOrConstSplat(RHS1); + ISD::CondCode CCL = cast<CondCodeSDNode>(LHS.getOperand(2))->get(); + ISD::CondCode CCR = cast<CondCodeSDNode>(RHS.getOperand(2))->get(); + EVT VT = LogicOp->getValueType(0); + EVT OpVT = LHS0.getValueType(); + SDLoc DL(LogicOp); + + // Check if the operands of an and/or operation are comparisons and if they + // compare against the same value. Replace the and/or-cmp-cmp sequence with + // min/max cmp sequence. If LHS1 is equal to RHS1, then the or-cmp-cmp + // sequence will be replaced with min-cmp sequence: + // (LHS0 < LHS1) | (RHS0 < RHS1) -> min(LHS0, RHS0) < LHS1 + // and and-cmp-cmp will be replaced with max-cmp sequence: + // (LHS0 < LHS1) & (RHS0 < RHS1) -> max(LHS0, RHS0) < LHS1 + if (OpVT.isInteger() && TLI.isOperationLegal(ISD::UMAX, OpVT) && + TLI.isOperationLegal(ISD::SMAX, OpVT) && + TLI.isOperationLegal(ISD::UMIN, OpVT) && + TLI.isOperationLegal(ISD::SMIN, OpVT)) { + if (LHS->getOpcode() == ISD::SETCC && RHS->getOpcode() == ISD::SETCC && + LHS->hasOneUse() && RHS->hasOneUse() && + // The two comparisons should have either the same predicate or the + // predicate of one of the comparisons is the opposite of the other one. + (CCL == CCR || CCL == ISD::getSetCCSwappedOperands(CCR)) && + // The optimization does not work for `==` or `!=` . + !ISD::isIntEqualitySetCC(CCL) && !ISD::isIntEqualitySetCC(CCR)) { + SDValue CommonValue, Operand1, Operand2; + ISD::CondCode CC = ISD::SETCC_INVALID; + if (CCL == CCR) { + if (LHS0 == RHS0) { + CommonValue = LHS0; + Operand1 = LHS1; + Operand2 = RHS1; + CC = ISD::getSetCCSwappedOperands(CCL); + } else if (LHS1 == RHS1) { + CommonValue = LHS1; + Operand1 = LHS0; + Operand2 = RHS0; + CC = CCL; + } + } else { + assert(CCL == ISD::getSetCCSwappedOperands(CCR) && "Unexpected CC"); + if (LHS0 == RHS1) { + CommonValue = LHS0; + Operand1 = LHS1; + Operand2 = RHS0; + CC = ISD::getSetCCSwappedOperands(CCL); + } else if (RHS0 == LHS1) { + CommonValue = LHS1; + Operand1 = LHS0; + Operand2 = RHS1; + CC = CCL; + } + } + + if (CC != ISD::SETCC_INVALID) { + unsigned NewOpcode; + bool IsSigned = isSignedIntSetCC(CC); + if (((CC == ISD::SETLE || CC == ISD::SETULE || CC == ISD::SETLT || + CC == ISD::SETULT) && + (LogicOp->getOpcode() == ISD::OR)) || + ((CC == ISD::SETGE || CC == ISD::SETUGE || CC == ISD::SETGT || + CC == ISD::SETUGT) && + (LogicOp->getOpcode() == ISD::AND))) + NewOpcode = IsSigned ? ISD::SMIN : ISD::UMIN; + else + NewOpcode = IsSigned ? ISD::SMAX : ISD::UMAX; + + SDValue MinMaxValue = + DAG.getNode(NewOpcode, DL, OpVT, Operand1, Operand2); + return DAG.getSetCC(DL, VT, MinMaxValue, CommonValue, CC); + } + } + } + + if (TargetPreference == AndOrSETCCFoldKind::None) + return SDValue(); + + if (CCL == CCR && + CCL == (LogicOp->getOpcode() == ISD::AND ? ISD::SETNE : ISD::SETEQ) && + LHS0 == RHS0 && LHS1C && RHS1C && OpVT.isInteger() && LHS.hasOneUse() && + RHS.hasOneUse()) { + const APInt &APLhs = LHS1C->getAPIntValue(); + const APInt &APRhs = RHS1C->getAPIntValue(); + + // Preference is to use ISD::ABS or we already have an ISD::ABS (in which + // case this is just a compare). + if (APLhs == (-APRhs) && + ((TargetPreference & AndOrSETCCFoldKind::ABS) || + DAG.doesNodeExist(ISD::ABS, DAG.getVTList(OpVT), {LHS0}))) { + const APInt &C = APLhs.isNegative() ? APRhs : APLhs; + // (icmp eq A, C) | (icmp eq A, -C) + // -> (icmp eq Abs(A), C) + // (icmp ne A, C) & (icmp ne A, -C) + // -> (icmp ne Abs(A), C) + SDValue AbsOp = DAG.getNode(ISD::ABS, DL, OpVT, LHS0); + return DAG.getNode(ISD::SETCC, DL, VT, AbsOp, + DAG.getConstant(C, DL, OpVT), LHS.getOperand(2)); + } else if (TargetPreference & + (AndOrSETCCFoldKind::AddAnd | AndOrSETCCFoldKind::NotAnd)) { + + // AndOrSETCCFoldKind::AddAnd: + // A == C0 | A == C1 + // IF IsPow2(smax(C0, C1)-smin(C0, C1)) + // -> ((A - smin(C0, C1)) & ~(smax(C0, C1)-smin(C0, C1))) == 0 + // A != C0 & A != C1 + // IF IsPow2(smax(C0, C1)-smin(C0, C1)) + // -> ((A - smin(C0, C1)) & ~(smax(C0, C1)-smin(C0, C1))) != 0 + + // AndOrSETCCFoldKind::NotAnd: + // A == C0 | A == C1 + // IF smax(C0, C1) == -1 AND IsPow2(smax(C0, C1) - smin(C0, C1)) + // -> ~A & smin(C0, C1) == 0 + // A != C0 & A != C1 + // IF smax(C0, C1) == -1 AND IsPow2(smax(C0, C1) - smin(C0, C1)) + // -> ~A & smin(C0, C1) != 0 + + const APInt &MaxC = APIntOps::smax(APRhs, APLhs); + const APInt &MinC = APIntOps::smin(APRhs, APLhs); + APInt Dif = MaxC - MinC; + if (!Dif.isZero() && Dif.isPowerOf2()) { + if (MaxC.isAllOnes() && + (TargetPreference & AndOrSETCCFoldKind::NotAnd)) { + SDValue NotOp = DAG.getNOT(DL, LHS0, OpVT); + SDValue AndOp = DAG.getNode(ISD::AND, DL, OpVT, NotOp, + DAG.getConstant(MinC, DL, OpVT)); + return DAG.getNode(ISD::SETCC, DL, VT, AndOp, + DAG.getConstant(0, DL, OpVT), LHS.getOperand(2)); + } else if (TargetPreference & AndOrSETCCFoldKind::AddAnd) { + + SDValue AddOp = DAG.getNode(ISD::ADD, DL, OpVT, LHS0, + DAG.getConstant(-MinC, DL, OpVT)); + SDValue AndOp = DAG.getNode(ISD::AND, DL, OpVT, AddOp, + DAG.getConstant(~Dif, DL, OpVT)); + return DAG.getNode(ISD::SETCC, DL, VT, AndOp, + DAG.getConstant(0, DL, OpVT), LHS.getOperand(2)); + } + } + } + } + + return SDValue(); +} + /// This contains all DAGCombine rules which reduce two values combined by /// an And operation to a single value. This makes them reusable in the context /// of visitSELECT(). Rules involving constants are not included as @@ -5644,6 +6194,11 @@ SDValue DAGCombiner::visitANDLike(SDValue N0, SDValue N1, SDNode *N) { if (SDValue V = foldLogicOfSetCCs(true, N0, N1, DL)) return V; + // Canonicalize: + // and(x, add) -> and(add, x) + if (N1.getOpcode() == ISD::ADD) + std::swap(N0, N1); + // TODO: Rewrite this to return a new 'AND' instead of using CombineTo. if (N0.getOpcode() == ISD::ADD && N1.getOpcode() == ISD::SRL && VT.getSizeInBits() <= 64 && N0->hasOneUse()) { @@ -5655,8 +6210,7 @@ SDValue DAGCombiner::visitANDLike(SDValue N0, SDValue N1, SDNode *N) { // in a register. APInt ADDC = ADDI->getAPIntValue(); APInt SRLC = SRLI->getAPIntValue(); - if (ADDC.getMinSignedBits() <= 64 && - SRLC.ult(VT.getSizeInBits()) && + if (ADDC.getSignificantBits() <= 64 && SRLC.ult(VT.getSizeInBits()) && !TLI.isLegalAddImmediate(ADDC.getSExtValue())) { APInt Mask = APInt::getHighBitsSet(VT.getSizeInBits(), SRLC.getZExtValue()); @@ -5677,55 +6231,6 @@ SDValue DAGCombiner::visitANDLike(SDValue N0, SDValue N1, SDNode *N) { } } - // Reduce bit extract of low half of an integer to the narrower type. - // (and (srl i64:x, K), KMask) -> - // (i64 zero_extend (and (srl (i32 (trunc i64:x)), K)), KMask) - if (N0.getOpcode() == ISD::SRL && N0.hasOneUse()) { - if (ConstantSDNode *CAnd = dyn_cast<ConstantSDNode>(N1)) { - if (ConstantSDNode *CShift = dyn_cast<ConstantSDNode>(N0.getOperand(1))) { - unsigned Size = VT.getSizeInBits(); - const APInt &AndMask = CAnd->getAPIntValue(); - unsigned ShiftBits = CShift->getZExtValue(); - - // Bail out, this node will probably disappear anyway. - if (ShiftBits == 0) - return SDValue(); - - unsigned MaskBits = AndMask.countTrailingOnes(); - EVT HalfVT = EVT::getIntegerVT(*DAG.getContext(), Size / 2); - - if (AndMask.isMask() && - // Required bits must not span the two halves of the integer and - // must fit in the half size type. - (ShiftBits + MaskBits <= Size / 2) && - TLI.isNarrowingProfitable(VT, HalfVT) && - TLI.isTypeDesirableForOp(ISD::AND, HalfVT) && - TLI.isTypeDesirableForOp(ISD::SRL, HalfVT) && - TLI.isTruncateFree(VT, HalfVT) && - TLI.isZExtFree(HalfVT, VT)) { - // The isNarrowingProfitable is to avoid regressions on PPC and - // AArch64 which match a few 64-bit bit insert / bit extract patterns - // on downstream users of this. Those patterns could probably be - // extended to handle extensions mixed in. - - SDValue SL(N0); - assert(MaskBits <= Size); - - // Extracting the highest bit of the low half. - EVT ShiftVT = TLI.getShiftAmountTy(HalfVT, DAG.getDataLayout()); - SDValue Trunc = DAG.getNode(ISD::TRUNCATE, SL, HalfVT, - N0.getOperand(0)); - - SDValue NewMask = DAG.getConstant(AndMask.trunc(Size / 2), SL, HalfVT); - SDValue ShiftK = DAG.getConstant(ShiftBits, SL, ShiftVT); - SDValue Shift = DAG.getNode(ISD::SRL, SL, HalfVT, Trunc, ShiftK); - SDValue And = DAG.getNode(ISD::AND, SL, HalfVT, Shift, NewMask); - return DAG.getNode(ISD::ZERO_EXTEND, SL, VT, And); - } - } - } - } - return SDValue(); } @@ -5734,7 +6239,7 @@ bool DAGCombiner::isAndLoadExtLoad(ConstantSDNode *AndC, LoadSDNode *LoadN, if (!AndC->getAPIntValue().isMask()) return false; - unsigned ActiveBits = AndC->getAPIntValue().countTrailingOnes(); + unsigned ActiveBits = AndC->getAPIntValue().countr_one(); ExtVT = EVT::getIntegerVT(*DAG.getContext(), ActiveBits); EVT LoadedVT = LoadN->getMemoryVT(); @@ -5898,7 +6403,7 @@ bool DAGCombiner::SearchForAndLoads(SDNode *N, } case ISD::ZERO_EXTEND: case ISD::AssertZext: { - unsigned ActiveBits = Mask->getAPIntValue().countTrailingOnes(); + unsigned ActiveBits = Mask->getAPIntValue().countr_one(); EVT ExtVT = EVT::getIntegerVT(*DAG.getContext(), ActiveBits); EVT VT = Op.getOpcode() == ISD::AssertZext ? cast<VTSDNode>(Op.getOperand(1))->getVT() : @@ -6071,12 +6576,6 @@ SDValue DAGCombiner::unfoldExtremeBitClearingToShifts(SDNode *N) { static SDValue combineShiftAnd1ToBitTest(SDNode *And, SelectionDAG &DAG) { assert(And->getOpcode() == ISD::AND && "Expected an 'and' op"); - // This is probably not worthwhile without a supported type. - EVT VT = And->getValueType(0); - const TargetLowering &TLI = DAG.getTargetLoweringInfo(); - if (!TLI.isTypeLegal(VT)) - return SDValue(); - // Look through an optional extension. SDValue And0 = And->getOperand(0), And1 = And->getOperand(1); if (And0.getOpcode() == ISD::ANY_EXTEND && And0.hasOneUse()) @@ -6104,13 +6603,17 @@ static SDValue combineShiftAnd1ToBitTest(SDNode *And, SelectionDAG &DAG) { if (Src.getOpcode() != ISD::SRL || !Src.hasOneUse()) return SDValue(); + // This is probably not worthwhile without a supported type. + EVT SrcVT = Src.getValueType(); + const TargetLowering &TLI = DAG.getTargetLoweringInfo(); + if (!TLI.isTypeLegal(SrcVT)) + return SDValue(); + // We might have looked through casts that make this transform invalid. - // TODO: If the source type is wider than the result type, do the mask and - // compare in the source type. - unsigned VTBitWidth = VT.getScalarSizeInBits(); + unsigned BitWidth = SrcVT.getScalarSizeInBits(); SDValue ShiftAmt = Src.getOperand(1); auto *ShiftAmtC = dyn_cast<ConstantSDNode>(ShiftAmt); - if (!ShiftAmtC || !ShiftAmtC->getAPIntValue().ult(VTBitWidth)) + if (!ShiftAmtC || !ShiftAmtC->getAPIntValue().ult(BitWidth)) return SDValue(); // Set source to shift source. @@ -6131,14 +6634,15 @@ static SDValue combineShiftAnd1ToBitTest(SDNode *And, SelectionDAG &DAG) { // and (not (srl X, C)), 1 --> (and X, 1<<C) == 0 // and (srl (not X), C)), 1 --> (and X, 1<<C) == 0 SDLoc DL(And); - SDValue X = DAG.getZExtOrTrunc(Src, DL, VT); - EVT CCVT = TLI.getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), VT); + SDValue X = DAG.getZExtOrTrunc(Src, DL, SrcVT); + EVT CCVT = + TLI.getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), SrcVT); SDValue Mask = DAG.getConstant( - APInt::getOneBitSet(VTBitWidth, ShiftAmtC->getZExtValue()), DL, VT); - SDValue NewAnd = DAG.getNode(ISD::AND, DL, VT, X, Mask); - SDValue Zero = DAG.getConstant(0, DL, VT); + APInt::getOneBitSet(BitWidth, ShiftAmtC->getZExtValue()), DL, SrcVT); + SDValue NewAnd = DAG.getNode(ISD::AND, DL, SrcVT, X, Mask); + SDValue Zero = DAG.getConstant(0, DL, SrcVT); SDValue Setcc = DAG.getSetCC(DL, CCVT, NewAnd, Zero, ISD::SETEQ); - return DAG.getZExtOrTrunc(Setcc, DL, VT); + return DAG.getZExtOrTrunc(Setcc, DL, And->getValueType(0)); } /// For targets that support usubsat, match a bit-hack form of that operation @@ -6181,9 +6685,8 @@ static SDValue foldAndToUsubsat(SDNode *N, SelectionDAG &DAG) { static SDValue foldLogicOfShifts(SDNode *N, SDValue LogicOp, SDValue ShiftOp, SelectionDAG &DAG) { unsigned LogicOpcode = N->getOpcode(); - assert((LogicOpcode == ISD::AND || LogicOpcode == ISD::OR || - LogicOpcode == ISD::XOR) - && "Expected bitwise logic operation"); + assert(ISD::isBitwiseLogicOp(LogicOpcode) && + "Expected bitwise logic operation"); if (!LogicOp.hasOneUse() || !ShiftOp.hasOneUse()) return SDValue(); @@ -6230,8 +6733,8 @@ static SDValue foldLogicOfShifts(SDNode *N, SDValue LogicOp, SDValue ShiftOp, static SDValue foldLogicTreeOfShifts(SDNode *N, SDValue LeftHand, SDValue RightHand, SelectionDAG &DAG) { unsigned LogicOpcode = N->getOpcode(); - assert((LogicOpcode == ISD::AND || LogicOpcode == ISD::OR || - LogicOpcode == ISD::XOR)); + assert(ISD::isBitwiseLogicOp(LogicOpcode) && + "Expected bitwise logic operation"); if (LeftHand.getOpcode() != LogicOpcode || RightHand.getOpcode() != LogicOpcode) return SDValue(); @@ -6276,6 +6779,10 @@ SDValue DAGCombiner::visitAND(SDNode *N) { !DAG.isConstantIntBuildVectorOrConstantInt(N1)) return DAG.getNode(ISD::AND, SDLoc(N), VT, N1, N0); + if (areBitwiseNotOfEachother(N0, N1)) + return DAG.getConstant(APInt::getZero(VT.getScalarSizeInBits()), SDLoc(N), + VT); + // fold vector ops if (VT.isVector()) { if (SDValue FoldedVOp = SimplifyVBinOp(N, SDLoc(N))) @@ -6330,6 +6837,9 @@ SDValue DAGCombiner::visitAND(SDNode *N) { if (N1C && DAG.MaskedValueIsZero(SDValue(N, 0), APInt::getAllOnes(BitWidth))) return DAG.getConstant(0, SDLoc(N), VT); + if (SDValue R = foldAndOrOfSETCC(N, DAG)) + return R; + if (SDValue NewSel = foldBinOpIntoSelect(N)) return NewSel; @@ -6337,6 +6847,11 @@ SDValue DAGCombiner::visitAND(SDNode *N) { if (SDValue RAND = reassociateOps(ISD::AND, SDLoc(N), N0, N1, N->getFlags())) return RAND; + // Fold and(vecreduce(x), vecreduce(y)) -> vecreduce(and(x, y)) + if (SDValue SD = reassociateReduction(ISD::VECREDUCE_AND, ISD::AND, SDLoc(N), + VT, N0, N1)) + return SD; + // fold (and (or x, C), D) -> D if (C & D) == D auto MatchSubset = [](ConstantSDNode *LHS, ConstantSDNode *RHS) { return RHS->getAPIntValue().isSubsetOf(LHS->getAPIntValue()); @@ -6345,13 +6860,27 @@ SDValue DAGCombiner::visitAND(SDNode *N) { ISD::matchBinaryPredicate(N0.getOperand(1), N1, MatchSubset)) return N1; - // fold (and (any_ext V), c) -> (zero_ext V) if 'and' only clears top bits. if (N1C && N0.getOpcode() == ISD::ANY_EXTEND) { SDValue N0Op0 = N0.getOperand(0); + EVT SrcVT = N0Op0.getValueType(); + unsigned SrcBitWidth = SrcVT.getScalarSizeInBits(); APInt Mask = ~N1C->getAPIntValue(); - Mask = Mask.trunc(N0Op0.getScalarValueSizeInBits()); + Mask = Mask.trunc(SrcBitWidth); + + // fold (and (any_ext V), c) -> (zero_ext V) if 'and' only clears top bits. if (DAG.MaskedValueIsZero(N0Op0, Mask)) - return DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N), N0.getValueType(), N0Op0); + return DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N), VT, N0Op0); + + // fold (and (any_ext V), c) -> (zero_ext (and (trunc V), c)) if profitable. + if (N1C->getAPIntValue().countLeadingZeros() >= (BitWidth - SrcBitWidth) && + TLI.isTruncateFree(VT, SrcVT) && TLI.isZExtFree(SrcVT, VT) && + TLI.isTypeDesirableForOp(ISD::AND, SrcVT) && + TLI.isNarrowingProfitable(VT, SrcVT)) { + SDLoc DL(N); + return DAG.getNode(ISD::ZERO_EXTEND, DL, VT, + DAG.getNode(ISD::AND, DL, SrcVT, N0Op0, + DAG.getZExtOrTrunc(N1, DL, SrcVT))); + } } // fold (and (ext (and V, c1)), c2) -> (and (ext V), (and c1, (ext c2))) @@ -7046,24 +7575,39 @@ SDValue DAGCombiner::visitORLike(SDValue N0, SDValue N1, SDNode *N) { static SDValue visitORCommutative(SelectionDAG &DAG, SDValue N0, SDValue N1, SDNode *N) { EVT VT = N0.getValueType(); - if (N0.getOpcode() == ISD::AND) { - SDValue N00 = N0.getOperand(0); - SDValue N01 = N0.getOperand(1); + + auto peekThroughResize = [](SDValue V) { + if (V->getOpcode() == ISD::ZERO_EXTEND || V->getOpcode() == ISD::TRUNCATE) + return V->getOperand(0); + return V; + }; + + SDValue N0Resized = peekThroughResize(N0); + if (N0Resized.getOpcode() == ISD::AND) { + SDValue N1Resized = peekThroughResize(N1); + SDValue N00 = N0Resized.getOperand(0); + SDValue N01 = N0Resized.getOperand(1); // fold or (and x, y), x --> x - if (N00 == N1 || N01 == N1) + if (N00 == N1Resized || N01 == N1Resized) return N1; // fold (or (and X, (xor Y, -1)), Y) -> (or X, Y) // TODO: Set AllowUndefs = true. - if (getBitwiseNotOperand(N01, N00, - /* AllowUndefs */ false) == N1) - return DAG.getNode(ISD::OR, SDLoc(N), VT, N00, N1); + if (SDValue NotOperand = getBitwiseNotOperand(N01, N00, + /* AllowUndefs */ false)) { + if (peekThroughResize(NotOperand) == N1Resized) + return DAG.getNode(ISD::OR, SDLoc(N), VT, + DAG.getZExtOrTrunc(N00, SDLoc(N), VT), N1); + } // fold (or (and (xor Y, -1), X), Y) -> (or X, Y) - if (getBitwiseNotOperand(N00, N01, - /* AllowUndefs */ false) == N1) - return DAG.getNode(ISD::OR, SDLoc(N), VT, N01, N1); + if (SDValue NotOperand = getBitwiseNotOperand(N00, N01, + /* AllowUndefs */ false)) { + if (peekThroughResize(NotOperand) == N1Resized) + return DAG.getNode(ISD::OR, SDLoc(N), VT, + DAG.getZExtOrTrunc(N01, SDLoc(N), VT), N1); + } } if (N0.getOpcode() == ISD::XOR) { @@ -7215,6 +7759,9 @@ SDValue DAGCombiner::visitOR(SDNode *N) { if (N1C && DAG.MaskedValueIsZero(N0, ~N1C->getAPIntValue())) return N1; + if (SDValue R = foldAndOrOfSETCC(N, DAG)) + return R; + if (SDValue Combined = visitORLike(N0, N1, N)) return Combined; @@ -7231,6 +7778,11 @@ SDValue DAGCombiner::visitOR(SDNode *N) { if (SDValue ROR = reassociateOps(ISD::OR, SDLoc(N), N0, N1, N->getFlags())) return ROR; + // Fold or(vecreduce(x), vecreduce(y)) -> vecreduce(or(x, y)) + if (SDValue SD = reassociateReduction(ISD::VECREDUCE_OR, ISD::OR, SDLoc(N), + VT, N0, N1)) + return SD; + // Canonicalize (or (and X, c1), c2) -> (and (or X, c2), c1|c2) // iff (c1 & c2) != 0 or c1/c2 are undef. auto MatchIntersect = [](ConstantSDNode *C1, ConstantSDNode *C2) { @@ -7898,42 +8450,6 @@ SDValue DAGCombiner::MatchRotate(SDValue LHS, SDValue RHS, const SDLoc &DL) { return SDValue(); } -namespace { - -/// Represents known origin of an individual byte in load combine pattern. The -/// value of the byte is either constant zero or comes from memory. -struct ByteProvider { - // For constant zero providers Load is set to nullptr. For memory providers - // Load represents the node which loads the byte from memory. - // ByteOffset is the offset of the byte in the value produced by the load. - LoadSDNode *Load = nullptr; - unsigned ByteOffset = 0; - unsigned VectorOffset = 0; - - ByteProvider() = default; - - static ByteProvider getMemory(LoadSDNode *Load, unsigned ByteOffset, - unsigned VectorOffset) { - return ByteProvider(Load, ByteOffset, VectorOffset); - } - - static ByteProvider getConstantZero() { return ByteProvider(nullptr, 0, 0); } - - bool isConstantZero() const { return !Load; } - bool isMemory() const { return Load; } - - bool operator==(const ByteProvider &Other) const { - return Other.Load == Load && Other.ByteOffset == ByteOffset && - Other.VectorOffset == VectorOffset; - } - -private: - ByteProvider(LoadSDNode *Load, unsigned ByteOffset, unsigned VectorOffset) - : Load(Load), ByteOffset(ByteOffset), VectorOffset(VectorOffset) {} -}; - -} // end anonymous namespace - /// Recursively traverses the expression calculating the origin of the requested /// byte of the given value. Returns std::nullopt if the provider can't be /// calculated. @@ -7975,7 +8491,9 @@ private: /// LOAD /// /// *ExtractVectorElement -static const std::optional<ByteProvider> +using SDByteProvider = ByteProvider<SDNode *>; + +static const std::optional<SDByteProvider> calculateByteProvider(SDValue Op, unsigned Index, unsigned Depth, std::optional<uint64_t> VectorIndex, unsigned StartingIndex = 0) { @@ -8034,7 +8552,7 @@ calculateByteProvider(SDValue Op, unsigned Index, unsigned Depth, // provide, then do not provide anything. Otherwise, subtract the index by // the amount we shifted by. return Index < ByteShift - ? ByteProvider::getConstantZero() + ? SDByteProvider::getConstantZero() : calculateByteProvider(Op->getOperand(0), Index - ByteShift, Depth + 1, VectorIndex, Index); } @@ -8049,7 +8567,8 @@ calculateByteProvider(SDValue Op, unsigned Index, unsigned Depth, if (Index >= NarrowByteWidth) return Op.getOpcode() == ISD::ZERO_EXTEND - ? std::optional<ByteProvider>(ByteProvider::getConstantZero()) + ? std::optional<SDByteProvider>( + SDByteProvider::getConstantZero()) : std::nullopt; return calculateByteProvider(NarrowOp, Index, Depth + 1, VectorIndex, StartingIndex); @@ -8099,11 +8618,12 @@ calculateByteProvider(SDValue Op, unsigned Index, unsigned Depth, // question if (Index >= NarrowByteWidth) return L->getExtensionType() == ISD::ZEXTLOAD - ? std::optional<ByteProvider>(ByteProvider::getConstantZero()) + ? std::optional<SDByteProvider>( + SDByteProvider::getConstantZero()) : std::nullopt; unsigned BPVectorIndex = VectorIndex.value_or(0U); - return ByteProvider::getMemory(L, Index, BPVectorIndex); + return SDByteProvider::getSrc(L, Index, BPVectorIndex); } } @@ -8191,9 +8711,12 @@ SDValue DAGCombiner::mergeTruncStores(StoreSDNode *N) { !N->isSimple() || N->isIndexed()) return SDValue(); - // Collect all of the stores in the chain. + // Collect all of the stores in the chain, upto the maximum store width (i64). SDValue Chain = N->getChain(); SmallVector<StoreSDNode *, 8> Stores = {N}; + unsigned NarrowNumBits = MemVT.getScalarSizeInBits(); + unsigned MaxWideNumBits = 64; + unsigned MaxStores = MaxWideNumBits / NarrowNumBits; while (auto *Store = dyn_cast<StoreSDNode>(Chain)) { // All stores must be the same size to ensure that we are writing all of the // bytes in the wide value. @@ -8207,6 +8730,8 @@ SDValue DAGCombiner::mergeTruncStores(StoreSDNode *N) { return SDValue(); Stores.push_back(Store); Chain = Store->getChain(); + if (MaxStores < Stores.size()) + return SDValue(); } // There is no reason to continue if we do not have at least a pair of stores. if (Stores.size() < 2) @@ -8215,7 +8740,6 @@ SDValue DAGCombiner::mergeTruncStores(StoreSDNode *N) { // Handle simple types only. LLVMContext &Context = *DAG.getContext(); unsigned NumStores = Stores.size(); - unsigned NarrowNumBits = N->getMemoryVT().getScalarSizeInBits(); unsigned WideNumBits = NumStores * NarrowNumBits; EVT WideVT = EVT::getIntegerVT(Context, WideNumBits); if (WideVT != MVT::i16 && WideVT != MVT::i32 && WideVT != MVT::i64) @@ -8397,23 +8921,24 @@ SDValue DAGCombiner::MatchLoadCombine(SDNode *N) { unsigned ByteWidth = VT.getSizeInBits() / 8; bool IsBigEndianTarget = DAG.getDataLayout().isBigEndian(); - auto MemoryByteOffset = [&] (ByteProvider P) { - assert(P.isMemory() && "Must be a memory byte provider"); - unsigned LoadBitWidth = P.Load->getMemoryVT().getScalarSizeInBits(); + auto MemoryByteOffset = [&](SDByteProvider P) { + assert(P.hasSrc() && "Must be a memory byte provider"); + auto *Load = cast<LoadSDNode>(P.Src.value()); + + unsigned LoadBitWidth = Load->getMemoryVT().getScalarSizeInBits(); assert(LoadBitWidth % 8 == 0 && "can only analyze providers for individual bytes not bit"); unsigned LoadByteWidth = LoadBitWidth / 8; - return IsBigEndianTarget - ? bigEndianByteAt(LoadByteWidth, P.ByteOffset) - : littleEndianByteAt(LoadByteWidth, P.ByteOffset); + return IsBigEndianTarget ? bigEndianByteAt(LoadByteWidth, P.DestOffset) + : littleEndianByteAt(LoadByteWidth, P.DestOffset); }; std::optional<BaseIndexOffset> Base; SDValue Chain; SmallPtrSet<LoadSDNode *, 8> Loads; - std::optional<ByteProvider> FirstByteProvider; + std::optional<SDByteProvider> FirstByteProvider; int64_t FirstOffset = INT64_MAX; // Check if all the bytes of the OR we are looking at are loaded from the same @@ -8434,9 +8959,8 @@ SDValue DAGCombiner::MatchLoadCombine(SDNode *N) { return SDValue(); continue; } - assert(P->isMemory() && "provenance should either be memory or zero"); - - LoadSDNode *L = P->Load; + assert(P->hasSrc() && "provenance should either be memory or zero"); + auto *L = cast<LoadSDNode>(P->Src.value()); // All loads must share the same chain SDValue LChain = L->getChain(); @@ -8460,7 +8984,7 @@ SDValue DAGCombiner::MatchLoadCombine(SDNode *N) { unsigned LoadWidthInBit = L->getMemoryVT().getScalarSizeInBits(); if (LoadWidthInBit % 8 != 0) return SDValue(); - unsigned ByteOffsetFromVector = P->VectorOffset * LoadWidthInBit / 8; + unsigned ByteOffsetFromVector = P->SrcOffset * LoadWidthInBit / 8; Ptr.addToOffset(ByteOffsetFromVector); } @@ -8517,7 +9041,7 @@ SDValue DAGCombiner::MatchLoadCombine(SDNode *N) { // So the combined value can be loaded from the first load address. if (MemoryByteOffset(*FirstByteProvider) != 0) return SDValue(); - LoadSDNode *FirstLoad = FirstByteProvider->Load; + auto *FirstLoad = cast<LoadSDNode>(FirstByteProvider->Src.value()); // The node we are looking at matches with the pattern, check if we can // replace it with a single (possibly zero-extended) load and bswap + shift if @@ -8715,6 +9239,11 @@ SDValue DAGCombiner::visitXOR(SDNode *N) { if (SDValue RXOR = reassociateOps(ISD::XOR, DL, N0, N1, N->getFlags())) return RXOR; + // Fold xor(vecreduce(x), vecreduce(y)) -> vecreduce(xor(x, y)) + if (SDValue SD = + reassociateReduction(ISD::VECREDUCE_XOR, ISD::XOR, DL, VT, N0, N1)) + return SD; + // fold (a^b) -> (a|b) iff a and b share no bits. if ((!LegalOperations || TLI.isOperationLegal(ISD::OR, VT)) && DAG.haveNoCommonBitsSet(N0, N1)) @@ -9462,7 +9991,7 @@ static SDValue combineShiftToMULH(SDNode *N, SelectionDAG &DAG, SDValue MulhRightOp; if (ConstantSDNode *Constant = isConstOrConstSplat(RightOp)) { unsigned ActiveBits = IsSignExt - ? Constant->getAPIntValue().getMinSignedBits() + ? Constant->getAPIntValue().getSignificantBits() : Constant->getAPIntValue().getActiveBits(); if (ActiveBits > NarrowVTSize) return SDValue(); @@ -9499,14 +10028,59 @@ static SDValue combineShiftToMULH(SDNode *N, SelectionDAG &DAG, // we use mulhs. Othewise, zero extends (zext) use mulhu. unsigned MulhOpcode = IsSignExt ? ISD::MULHS : ISD::MULHU; - // Combine to mulh if mulh is legal/custom for the narrow type on the target. - if (!TLI.isOperationLegalOrCustom(MulhOpcode, NarrowVT)) - return SDValue(); + // Combine to mulh if mulh is legal/custom for the narrow type on the target + // or if it is a vector type then we could transform to an acceptable type and + // rely on legalization to split/combine the result. + if (NarrowVT.isVector()) { + EVT TransformVT = TLI.getTypeToTransformTo(*DAG.getContext(), NarrowVT); + if (TransformVT.getVectorElementType() != NarrowVT.getVectorElementType() || + !TLI.isOperationLegalOrCustom(MulhOpcode, TransformVT)) + return SDValue(); + } else { + if (!TLI.isOperationLegalOrCustom(MulhOpcode, NarrowVT)) + return SDValue(); + } SDValue Result = DAG.getNode(MulhOpcode, DL, NarrowVT, LeftOp.getOperand(0), MulhRightOp); - return (N->getOpcode() == ISD::SRA ? DAG.getSExtOrTrunc(Result, DL, WideVT) - : DAG.getZExtOrTrunc(Result, DL, WideVT)); + bool IsSigned = N->getOpcode() == ISD::SRA; + return DAG.getExtOrTrunc(IsSigned, Result, DL, WideVT); +} + +// fold (bswap (logic_op(bswap(x),y))) -> logic_op(x,bswap(y)) +// This helper function accept SDNode with opcode ISD::BSWAP and ISD::BITREVERSE +static SDValue foldBitOrderCrossLogicOp(SDNode *N, SelectionDAG &DAG) { + unsigned Opcode = N->getOpcode(); + if (Opcode != ISD::BSWAP && Opcode != ISD::BITREVERSE) + return SDValue(); + + SDValue N0 = N->getOperand(0); + EVT VT = N->getValueType(0); + SDLoc DL(N); + if (ISD::isBitwiseLogicOp(N0.getOpcode()) && N0.hasOneUse()) { + SDValue OldLHS = N0.getOperand(0); + SDValue OldRHS = N0.getOperand(1); + + // If both operands are bswap/bitreverse, ignore the multiuse + // Otherwise need to ensure logic_op and bswap/bitreverse(x) have one use. + if (OldLHS.getOpcode() == Opcode && OldRHS.getOpcode() == Opcode) { + return DAG.getNode(N0.getOpcode(), DL, VT, OldLHS.getOperand(0), + OldRHS.getOperand(0)); + } + + if (OldLHS.getOpcode() == Opcode && OldLHS.hasOneUse()) { + SDValue NewBitReorder = DAG.getNode(Opcode, DL, VT, OldRHS); + return DAG.getNode(N0.getOpcode(), DL, VT, OldLHS.getOperand(0), + NewBitReorder); + } + + if (OldRHS.getOpcode() == Opcode && OldRHS.hasOneUse()) { + SDValue NewBitReorder = DAG.getNode(Opcode, DL, VT, OldLHS); + return DAG.getNode(N0.getOpcode(), DL, VT, NewBitReorder, + OldRHS.getOperand(0)); + } + } + return SDValue(); } SDValue DAGCombiner::visitSRA(SDNode *N) { @@ -9892,8 +10466,10 @@ SDValue DAGCombiner::visitSRL(SDNode *N) { return DAG.getNode(ISD::SRL, SDLoc(N), VT, N0.getOperand(0), N1); } - // fold (srl (ctlz x), "5") -> x iff x has one bit set (the low bit). + // fold (srl (ctlz x), "5") -> x iff x has one bit set (the low bit), and x has a power + // of two bitwidth. The "5" represents (log2 (bitwidth x)). if (N1C && N0.getOpcode() == ISD::CTLZ && + isPowerOf2_32(OpSizeInBits) && N1C->getAPIntValue() == Log2_32(OpSizeInBits)) { KnownBits Known = DAG.computeKnownBits(N0.getOperand(0)); @@ -9912,7 +10488,7 @@ SDValue DAGCombiner::visitSRL(SDNode *N) { // could be set on input to the CTLZ node. If this bit is set, the SRL // will return 0, if it is clear, it returns 1. Change the CTLZ/SRL pair // to an SRL/XOR pair, which is likely to simplify more. - unsigned ShAmt = UnknownBits.countTrailingZeros(); + unsigned ShAmt = UnknownBits.countr_zero(); SDValue Op = N0.getOperand(0); if (ShAmt) { @@ -10138,13 +10714,23 @@ SDValue DAGCombiner::visitSHLSAT(SDNode *N) { return SDValue(); } -// Given a ABS node, detect the following pattern: +// Given a ABS node, detect the following patterns: // (ABS (SUB (EXTEND a), (EXTEND b))). +// (TRUNC (ABS (SUB (EXTEND a), (EXTEND b)))). // Generates UABD/SABD instruction. SDValue DAGCombiner::foldABSToABD(SDNode *N) { + EVT SrcVT = N->getValueType(0); + + if (N->getOpcode() == ISD::TRUNCATE) + N = N->getOperand(0).getNode(); + + if (N->getOpcode() != ISD::ABS) + return SDValue(); + EVT VT = N->getValueType(0); SDValue AbsOp1 = N->getOperand(0); SDValue Op0, Op1; + SDLoc DL(N); if (AbsOp1.getOpcode() != ISD::SUB) return SDValue(); @@ -10157,9 +10743,11 @@ SDValue DAGCombiner::foldABSToABD(SDNode *N) { if (Opc0 != Op1.getOpcode() || (Opc0 != ISD::ZERO_EXTEND && Opc0 != ISD::SIGN_EXTEND)) { // fold (abs (sub nsw x, y)) -> abds(x, y) - if (AbsOp1->getFlags().hasNoSignedWrap() && - TLI.isOperationLegalOrCustom(ISD::ABDS, VT)) - return DAG.getNode(ISD::ABDS, SDLoc(N), VT, Op0, Op1); + if (AbsOp1->getFlags().hasNoSignedWrap() && hasOperation(ISD::ABDS, VT) && + TLI.preferABDSToABSWithNSW(VT)) { + SDValue ABD = DAG.getNode(ISD::ABDS, DL, VT, Op0, Op1); + return DAG.getZExtOrTrunc(ABD, DL, SrcVT); + } return SDValue(); } @@ -10170,17 +10758,20 @@ SDValue DAGCombiner::foldABSToABD(SDNode *N) { // fold abs(sext(x) - sext(y)) -> zext(abds(x, y)) // fold abs(zext(x) - zext(y)) -> zext(abdu(x, y)) // NOTE: Extensions must be equivalent. - if (VT1 == VT2 && TLI.isOperationLegalOrCustom(ABDOpcode, VT1)) { + if (VT1 == VT2 && hasOperation(ABDOpcode, VT1)) { Op0 = Op0.getOperand(0); Op1 = Op1.getOperand(0); - SDValue ABD = DAG.getNode(ABDOpcode, SDLoc(N), VT1, Op0, Op1); - return DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N), VT, ABD); + SDValue ABD = DAG.getNode(ABDOpcode, DL, VT1, Op0, Op1); + ABD = DAG.getNode(ISD::ZERO_EXTEND, DL, VT, ABD); + return DAG.getZExtOrTrunc(ABD, DL, SrcVT); } // fold abs(sext(x) - sext(y)) -> abds(sext(x), sext(y)) // fold abs(zext(x) - zext(y)) -> abdu(zext(x), zext(y)) - if (TLI.isOperationLegalOrCustom(ABDOpcode, VT)) - return DAG.getNode(ABDOpcode, SDLoc(N), VT, Op0, Op1); + if (hasOperation(ABDOpcode, VT)) { + SDValue ABD = DAG.getNode(ABDOpcode, DL, VT, Op0, Op1); + return DAG.getZExtOrTrunc(ABD, DL, SrcVT); + } return SDValue(); } @@ -10190,8 +10781,8 @@ SDValue DAGCombiner::visitABS(SDNode *N) { EVT VT = N->getValueType(0); // fold (abs c1) -> c2 - if (DAG.isConstantIntBuildVectorOrConstantInt(N0)) - return DAG.getNode(ISD::ABS, SDLoc(N), VT, N0); + if (SDValue C = DAG.FoldConstantArithmetic(ISD::ABS, SDLoc(N), VT, {N0})) + return C; // fold (abs (abs x)) -> (abs x) if (N0.getOpcode() == ISD::ABS) return N0; @@ -10277,6 +10868,9 @@ SDValue DAGCombiner::visitBSWAP(SDNode *N) { } } + if (SDValue V = foldBitOrderCrossLogicOp(N, DAG)) + return V; + return SDValue(); } @@ -10447,7 +11041,8 @@ SDValue DAGCombiner::combineMinNumMaxNum(const SDLoc &DL, EVT VT, SDValue LHS, if (NegRHS == False) { SDValue Combined = combineMinNumMaxNumImpl(DL, VT, LHS, RHS, NegTrue, False, CC, TLI, DAG); - return DAG.getNode(ISD::FNEG, DL, VT, Combined); + if (Combined) + return DAG.getNode(ISD::FNEG, DL, VT, Combined); } } } @@ -11091,6 +11686,23 @@ SDValue DAGCombiner::visitMSTORE(SDNode *N) { if (ISD::isConstantSplatVectorAllZeros(Mask.getNode())) return Chain; + // Remove a masked store if base pointers and masks are equal. + if (MaskedStoreSDNode *MST1 = dyn_cast<MaskedStoreSDNode>(Chain)) { + if (MST->isUnindexed() && MST->isSimple() && MST1->isUnindexed() && + MST1->isSimple() && MST1->getBasePtr() == Ptr && + !MST->getBasePtr().isUndef() && + ((Mask == MST1->getMask() && MST->getMemoryVT().getStoreSize() == + MST1->getMemoryVT().getStoreSize()) || + ISD::isConstantSplatVectorAllOnes(Mask.getNode())) && + TypeSize::isKnownLE(MST1->getMemoryVT().getStoreSize(), + MST->getMemoryVT().getStoreSize())) { + CombineTo(MST1, MST1->getChain()); + if (N->getOpcode() != ISD::DELETED_NODE) + AddToWorklist(N); + return SDValue(N, 0); + } + } + // If this is a masked load with an all ones mask, we can use a unmasked load. // FIXME: Can we do this for indexed, compressing, or truncating stores? if (ISD::isConstantSplatVectorAllOnes(Mask.getNode()) && MST->isUnindexed() && @@ -11391,6 +12003,38 @@ SDValue DAGCombiner::visitVSELECT(SDNode *N) { } } + // Match VSELECTs with absolute difference patterns. + // (vselect (setcc a, b, set?gt), (sub a, b), (sub b, a)) --> (abd? a, b) + // (vselect (setcc a, b, set?ge), (sub a, b), (sub b, a)) --> (abd? a, b) + // (vselect (setcc a, b, set?lt), (sub b, a), (sub a, b)) --> (abd? a, b) + // (vselect (setcc a, b, set?le), (sub b, a), (sub a, b)) --> (abd? a, b) + if (N1.getOpcode() == ISD::SUB && N2.getOpcode() == ISD::SUB && + N1.getOperand(0) == N2.getOperand(1) && + N1.getOperand(1) == N2.getOperand(0)) { + bool IsSigned = isSignedIntSetCC(CC); + unsigned ABDOpc = IsSigned ? ISD::ABDS : ISD::ABDU; + if (hasOperation(ABDOpc, VT)) { + switch (CC) { + case ISD::SETGT: + case ISD::SETGE: + case ISD::SETUGT: + case ISD::SETUGE: + if (LHS == N1.getOperand(0) && RHS == N1.getOperand(1)) + return DAG.getNode(ABDOpc, DL, VT, LHS, RHS); + break; + case ISD::SETLT: + case ISD::SETLE: + case ISD::SETULT: + case ISD::SETULE: + if (RHS == N1.getOperand(0) && LHS == N1.getOperand(1) ) + return DAG.getNode(ABDOpc, DL, VT, LHS, RHS); + break; + default: + break; + } + } + } + // Match VSELECTs into add with unsigned saturation. if (hasOperation(ISD::UADDSAT, VT)) { // Check if one of the arms of the VSELECT is vector with all bits set. @@ -11612,57 +12256,6 @@ SDValue DAGCombiner::visitSETCC(SDNode *N) { ISD::CondCode Cond = cast<CondCodeSDNode>(N->getOperand(2))->get(); EVT VT = N->getValueType(0); - // SETCC(FREEZE(X), CONST, Cond) - // => - // FREEZE(SETCC(X, CONST, Cond)) - // This is correct if FREEZE(X) has one use and SETCC(FREEZE(X), CONST, Cond) - // isn't equivalent to true or false. - // For example, SETCC(FREEZE(X), -128, SETULT) cannot be folded to - // FREEZE(SETCC(X, -128, SETULT)) because X can be poison. - // - // This transformation is beneficial because visitBRCOND can fold - // BRCOND(FREEZE(X)) to BRCOND(X). - - // Conservatively optimize integer comparisons only. - if (PreferSetCC) { - // Do this only when SETCC is going to be used by BRCOND. - - SDValue N0 = N->getOperand(0), N1 = N->getOperand(1); - ConstantSDNode *N0C = dyn_cast<ConstantSDNode>(N0); - ConstantSDNode *N1C = dyn_cast<ConstantSDNode>(N1); - bool Updated = false; - - // Is 'X Cond C' always true or false? - auto IsAlwaysTrueOrFalse = [](ISD::CondCode Cond, ConstantSDNode *C) { - bool False = (Cond == ISD::SETULT && C->isZero()) || - (Cond == ISD::SETLT && C->isMinSignedValue()) || - (Cond == ISD::SETUGT && C->isAllOnes()) || - (Cond == ISD::SETGT && C->isMaxSignedValue()); - bool True = (Cond == ISD::SETULE && C->isAllOnes()) || - (Cond == ISD::SETLE && C->isMaxSignedValue()) || - (Cond == ISD::SETUGE && C->isZero()) || - (Cond == ISD::SETGE && C->isMinSignedValue()); - return True || False; - }; - - if (N0->getOpcode() == ISD::FREEZE && N0.hasOneUse() && N1C) { - if (!IsAlwaysTrueOrFalse(Cond, N1C)) { - N0 = N0->getOperand(0); - Updated = true; - } - } - if (N1->getOpcode() == ISD::FREEZE && N1.hasOneUse() && N0C) { - if (!IsAlwaysTrueOrFalse(ISD::getSetCCSwappedOperands(Cond), - N0C)) { - N1 = N1->getOperand(0); - Updated = true; - } - } - - if (Updated) - return DAG.getFreeze(DAG.getSetCC(SDLoc(N), VT, N0, N1, Cond)); - } - SDValue Combined = SimplifySetCC(VT, N->getOperand(0), N->getOperand(1), Cond, SDLoc(N), !PreferSetCC); @@ -11733,7 +12326,8 @@ static bool isCompatibleLoad(SDValue N, unsigned ExtOpcode) { /// This function is called by the DAGCombiner when visiting sext/zext/aext /// dag nodes (see for example method DAGCombiner::visitSIGN_EXTEND). static SDValue tryToFoldExtendSelectLoad(SDNode *N, const TargetLowering &TLI, - SelectionDAG &DAG) { + SelectionDAG &DAG, + CombineLevel Level) { unsigned Opcode = N->getOpcode(); SDValue N0 = N->getOperand(0); EVT VT = N->getValueType(0); @@ -11758,10 +12352,14 @@ static SDValue tryToFoldExtendSelectLoad(SDNode *N, const TargetLowering &TLI, else if (Opcode == ISD::ZERO_EXTEND) ExtLoadOpcode = ISD::ZEXTLOAD; + // Illegal VSELECT may ISel fail if happen after legalization (DAG + // Combine2), so we should conservatively check the OperationAction. LoadSDNode *Load1 = cast<LoadSDNode>(Op1); LoadSDNode *Load2 = cast<LoadSDNode>(Op2); if (!TLI.isLoadExtLegal(ExtLoadOpcode, VT, Load1->getMemoryVT()) || - !TLI.isLoadExtLegal(ExtLoadOpcode, VT, Load2->getMemoryVT())) + !TLI.isLoadExtLegal(ExtLoadOpcode, VT, Load2->getMemoryVT()) || + (N0->getOpcode() == ISD::VSELECT && Level >= AfterLegalizeTypes && + TLI.getOperationAction(ISD::VSELECT, VT) != TargetLowering::Legal)) return SDValue(); SDValue Ext1 = DAG.getNode(Opcode, DL, VT, Op1); @@ -11782,11 +12380,7 @@ static SDValue tryToFoldExtendOfConstant(SDNode *N, const TargetLowering &TLI, EVT VT = N->getValueType(0); SDLoc DL(N); - assert((Opcode == ISD::SIGN_EXTEND || Opcode == ISD::ZERO_EXTEND || - Opcode == ISD::ANY_EXTEND || - Opcode == ISD::SIGN_EXTEND_VECTOR_INREG || - Opcode == ISD::ZERO_EXTEND_VECTOR_INREG || - Opcode == ISD::ANY_EXTEND_VECTOR_INREG) && + assert((ISD::isExtOpcode(Opcode) || ISD::isExtVecInRegOpcode(Opcode)) && "Expected EXTEND dag node in input!"); // fold (sext c1) -> c1 @@ -12052,8 +12646,7 @@ SDValue DAGCombiner::CombineZExtLogicopShiftLoad(SDNode *N) { // and/or/xor SDValue N0 = N->getOperand(0); - if (!(N0.getOpcode() == ISD::AND || N0.getOpcode() == ISD::OR || - N0.getOpcode() == ISD::XOR) || + if (!ISD::isBitwiseLogicOp(N0.getOpcode()) || N0.getOperand(1).getOpcode() != ISD::Constant || (LegalOperations && !TLI.isOperationLegal(N0.getOpcode(), VT))) return SDValue(); @@ -12449,11 +13042,19 @@ SDValue DAGCombiner::visitSIGN_EXTEND(SDNode *N) { if (N0.getOpcode() == ISD::SIGN_EXTEND || N0.getOpcode() == ISD::ANY_EXTEND) return DAG.getNode(ISD::SIGN_EXTEND, DL, VT, N0.getOperand(0)); + // fold (sext (aext_extend_vector_inreg x)) -> (sext_extend_vector_inreg x) + // fold (sext (sext_extend_vector_inreg x)) -> (sext_extend_vector_inreg x) + if (N0.getOpcode() == ISD::ANY_EXTEND_VECTOR_INREG || + N0.getOpcode() == ISD::SIGN_EXTEND_VECTOR_INREG) + return DAG.getNode(ISD::SIGN_EXTEND_VECTOR_INREG, SDLoc(N), VT, + N0.getOperand(0)); + // fold (sext (sext_inreg x)) -> (sext (trunc x)) if (N0.getOpcode() == ISD::SIGN_EXTEND_INREG) { SDValue N00 = N0.getOperand(0); EVT ExtVT = cast<VTSDNode>(N0->getOperand(1))->getVT(); - if (N00.getOpcode() == ISD::TRUNCATE && (!LegalOperations || TLI.isTypeLegal(ExtVT))) { + if (N00.getOpcode() == ISD::TRUNCATE && + (!LegalTypes || TLI.isTypeLegal(ExtVT))) { SDValue T = DAG.getNode(ISD::TRUNCATE, DL, ExtVT, N00.getOperand(0)); return DAG.getNode(ISD::SIGN_EXTEND, DL, VT, T); } @@ -12532,8 +13133,7 @@ SDValue DAGCombiner::visitSIGN_EXTEND(SDNode *N) { // fold (sext (and/or/xor (load x), cst)) -> // (and/or/xor (sextload x), (sext cst)) - if ((N0.getOpcode() == ISD::AND || N0.getOpcode() == ISD::OR || - N0.getOpcode() == ISD::XOR) && + if (ISD::isBitwiseLogicOp(N0.getOpcode()) && isa<LoadSDNode>(N0.getOperand(0)) && N0.getOperand(1).getOpcode() == ISD::Constant && (!LegalOperations && TLI.isOperationLegal(N0.getOpcode(), VT))) { @@ -12630,45 +13230,12 @@ SDValue DAGCombiner::visitSIGN_EXTEND(SDNode *N) { return DAG.getNode(ISD::ADD, DL, VT, Zext, DAG.getAllOnesConstant(DL, VT)); } - if (SDValue Res = tryToFoldExtendSelectLoad(N, TLI, DAG)) + if (SDValue Res = tryToFoldExtendSelectLoad(N, TLI, DAG, Level)) return Res; return SDValue(); } -// isTruncateOf - If N is a truncate of some other value, return true, record -// the value being truncated in Op and which of Op's bits are zero/one in Known. -// This function computes KnownBits to avoid a duplicated call to -// computeKnownBits in the caller. -static bool isTruncateOf(SelectionDAG &DAG, SDValue N, SDValue &Op, - KnownBits &Known) { - if (N->getOpcode() == ISD::TRUNCATE) { - Op = N->getOperand(0); - Known = DAG.computeKnownBits(Op); - return true; - } - - if (N.getOpcode() != ISD::SETCC || - N.getValueType().getScalarType() != MVT::i1 || - cast<CondCodeSDNode>(N.getOperand(2))->get() != ISD::SETNE) - return false; - - SDValue Op0 = N->getOperand(0); - SDValue Op1 = N->getOperand(1); - assert(Op0.getValueType() == Op1.getValueType()); - - if (isNullOrNullSplat(Op0)) - Op = Op1; - else if (isNullOrNullSplat(Op1)) - Op = Op0; - else - return false; - - Known = DAG.computeKnownBits(Op); - - return (Known.Zero | 1).isAllOnes(); -} - /// Given an extending node with a pop-count operand, if the target does not /// support a pop-count in the narrow source type but does support it in the /// destination type, widen the pop-count to the destination type. @@ -12722,14 +13289,15 @@ static SDValue widenAbs(SDNode *Extend, SelectionDAG &DAG) { SDValue DAGCombiner::visitZERO_EXTEND(SDNode *N) { SDValue N0 = N->getOperand(0); EVT VT = N->getValueType(0); + SDLoc DL(N); if (VT.isVector()) - if (SDValue FoldedVOp = SimplifyVCastOp(N, SDLoc(N))) + if (SDValue FoldedVOp = SimplifyVCastOp(N, DL)) return FoldedVOp; // zext(undef) = 0 if (N0.isUndef()) - return DAG.getConstant(0, SDLoc(N), VT); + return DAG.getConstant(0, DL, VT); if (SDValue Res = tryToFoldExtendOfConstant(N, TLI, DAG, LegalTypes)) return Res; @@ -12737,7 +13305,13 @@ SDValue DAGCombiner::visitZERO_EXTEND(SDNode *N) { // fold (zext (zext x)) -> (zext x) // fold (zext (aext x)) -> (zext x) if (N0.getOpcode() == ISD::ZERO_EXTEND || N0.getOpcode() == ISD::ANY_EXTEND) - return DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N), VT, + return DAG.getNode(ISD::ZERO_EXTEND, DL, VT, N0.getOperand(0)); + + // fold (zext (aext_extend_vector_inreg x)) -> (zext_extend_vector_inreg x) + // fold (zext (zext_extend_vector_inreg x)) -> (zext_extend_vector_inreg x) + if (N0.getOpcode() == ISD::ANY_EXTEND_VECTOR_INREG || + N0.getOpcode() == ISD::ZERO_EXTEND_VECTOR_INREG) + return DAG.getNode(ISD::ZERO_EXTEND_VECTOR_INREG, SDLoc(N), VT, N0.getOperand(0)); // fold (zext (truncate x)) -> (zext x) or @@ -12754,7 +13328,7 @@ SDValue DAGCombiner::visitZERO_EXTEND(SDNode *N) { std::min(Op.getScalarValueSizeInBits(), VT.getScalarSizeInBits())); if (TruncatedBits.isSubsetOf(Known.Zero)) - return DAG.getZExtOrTrunc(Op, SDLoc(N), VT); + return DAG.getZExtOrTrunc(Op, DL, VT); } // fold (zext (truncate x)) -> (and x, mask) @@ -12780,9 +13354,9 @@ SDValue DAGCombiner::visitZERO_EXTEND(SDNode *N) { if (!LegalOperations || (TLI.isOperationLegal(ISD::AND, SrcVT) && TLI.isOperationLegal(ISD::ZERO_EXTEND, VT))) { SDValue Op = N0.getOperand(0); - Op = DAG.getZeroExtendInReg(Op, SDLoc(N), MinVT); + Op = DAG.getZeroExtendInReg(Op, DL, MinVT); AddToWorklist(Op.getNode()); - SDValue ZExtOrTrunc = DAG.getZExtOrTrunc(Op, SDLoc(N), VT); + SDValue ZExtOrTrunc = DAG.getZExtOrTrunc(Op, DL, VT); // Transfer the debug info; the new node is equivalent to N0. DAG.transferDbgValues(N0, ZExtOrTrunc); return ZExtOrTrunc; @@ -12790,9 +13364,9 @@ SDValue DAGCombiner::visitZERO_EXTEND(SDNode *N) { } if (!LegalOperations || TLI.isOperationLegal(ISD::AND, VT)) { - SDValue Op = DAG.getAnyExtOrTrunc(N0.getOperand(0), SDLoc(N), VT); + SDValue Op = DAG.getAnyExtOrTrunc(N0.getOperand(0), DL, VT); AddToWorklist(Op.getNode()); - SDValue And = DAG.getZeroExtendInReg(Op, SDLoc(N), MinVT); + SDValue And = DAG.getZeroExtendInReg(Op, DL, MinVT); // We may safely transfer the debug info describing the truncate node over // to the equivalent and operation. DAG.transferDbgValues(N0, And); @@ -12811,7 +13385,6 @@ SDValue DAGCombiner::visitZERO_EXTEND(SDNode *N) { SDValue X = N0.getOperand(0).getOperand(0); X = DAG.getAnyExtOrTrunc(X, SDLoc(X), VT); APInt Mask = N0.getConstantOperandAPInt(1).zext(VT.getSizeInBits()); - SDLoc DL(N); return DAG.getNode(ISD::AND, DL, VT, X, DAG.getConstant(Mask, DL, VT)); } @@ -12836,8 +13409,7 @@ SDValue DAGCombiner::visitZERO_EXTEND(SDNode *N) { // (and/or/xor (zextload x), (zext cst)) // Unless (and (load x) cst) will match as a zextload already and has // additional users. - if ((N0.getOpcode() == ISD::AND || N0.getOpcode() == ISD::OR || - N0.getOpcode() == ISD::XOR) && + if (ISD::isBitwiseLogicOp(N0.getOpcode()) && isa<LoadSDNode>(N0.getOperand(0)) && N0.getOperand(1).getOpcode() == ISD::Constant && (!LegalOperations && TLI.isOperationLegal(N0.getOpcode(), VT))) { @@ -12865,7 +13437,6 @@ SDValue DAGCombiner::visitZERO_EXTEND(SDNode *N) { LN00->getMemoryVT(), LN00->getMemOperand()); APInt Mask = N0.getConstantOperandAPInt(1).zext(VT.getSizeInBits()); - SDLoc DL(N); SDValue And = DAG.getNode(N0.getOpcode(), DL, VT, ExtLoad, DAG.getConstant(Mask, DL, VT)); ExtendSetCCUses(SetCCs, N0.getOperand(0), ExtLoad, ISD::ZERO_EXTEND); @@ -12919,7 +13490,6 @@ SDValue DAGCombiner::visitZERO_EXTEND(SDNode *N) { // that matter). Check to see that they are the same size. If so, we know // that the element size of the sext'd result matches the element size of // the compare operands. - SDLoc DL(N); if (VT.getSizeInBits() == N00VT.getSizeInBits()) { // zext(setcc) -> zext_in_reg(vsetcc) for vectors. SDValue VSetCC = DAG.getNode(ISD::SETCC, DL, VT, N0.getOperand(0), @@ -12939,7 +13509,6 @@ SDValue DAGCombiner::visitZERO_EXTEND(SDNode *N) { } // zext(setcc x,y,cc) -> zext(select x, y, true, false, cc) - SDLoc DL(N); EVT N0VT = N0.getValueType(); EVT N00VT = N0.getOperand(0).getValueType(); if (SDValue SCC = SimplifySelectCC( @@ -12952,29 +13521,29 @@ SDValue DAGCombiner::visitZERO_EXTEND(SDNode *N) { // (zext (shl (zext x), cst)) -> (shl (zext x), cst) if ((N0.getOpcode() == ISD::SHL || N0.getOpcode() == ISD::SRL) && - isa<ConstantSDNode>(N0.getOperand(1)) && - N0.getOperand(0).getOpcode() == ISD::ZERO_EXTEND && - N0.hasOneUse()) { + !TLI.isZExtFree(N0, VT)) { + SDValue ShVal = N0.getOperand(0); SDValue ShAmt = N0.getOperand(1); - if (N0.getOpcode() == ISD::SHL) { - SDValue InnerZExt = N0.getOperand(0); - // If the original shl may be shifting out bits, do not perform this - // transformation. - unsigned KnownZeroBits = InnerZExt.getValueSizeInBits() - - InnerZExt.getOperand(0).getValueSizeInBits(); - if (cast<ConstantSDNode>(ShAmt)->getAPIntValue().ugt(KnownZeroBits)) - return SDValue(); - } - - SDLoc DL(N); + if (auto *ShAmtC = dyn_cast<ConstantSDNode>(ShAmt)) { + if (ShVal.getOpcode() == ISD::ZERO_EXTEND && N0.hasOneUse()) { + if (N0.getOpcode() == ISD::SHL) { + // If the original shl may be shifting out bits, do not perform this + // transformation. + // TODO: Add MaskedValueIsZero check. + unsigned KnownZeroBits = ShVal.getValueSizeInBits() - + ShVal.getOperand(0).getValueSizeInBits(); + if (ShAmtC->getAPIntValue().ugt(KnownZeroBits)) + return SDValue(); + } - // Ensure that the shift amount is wide enough for the shifted value. - if (Log2_32_Ceil(VT.getSizeInBits()) > ShAmt.getValueSizeInBits()) - ShAmt = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::i32, ShAmt); + // Ensure that the shift amount is wide enough for the shifted value. + if (Log2_32_Ceil(VT.getSizeInBits()) > ShAmt.getValueSizeInBits()) + ShAmt = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::i32, ShAmt); - return DAG.getNode(N0.getOpcode(), DL, VT, - DAG.getNode(ISD::ZERO_EXTEND, DL, VT, N0.getOperand(0)), - ShAmt); + return DAG.getNode(N0.getOpcode(), DL, VT, + DAG.getNode(ISD::ZERO_EXTEND, DL, VT, ShVal), ShAmt); + } + } } if (SDValue NewVSel = matchVSelectOpSizesWithSetCC(N)) @@ -12986,7 +13555,7 @@ SDValue DAGCombiner::visitZERO_EXTEND(SDNode *N) { if (SDValue V = widenAbs(N, DAG)) return V; - if (SDValue Res = tryToFoldExtendSelectLoad(N, TLI, DAG)) + if (SDValue Res = tryToFoldExtendSelectLoad(N, TLI, DAG, Level)) return Res; return SDValue(); @@ -13011,6 +13580,14 @@ SDValue DAGCombiner::visitANY_EXTEND(SDNode *N) { N0.getOpcode() == ISD::SIGN_EXTEND) return DAG.getNode(N0.getOpcode(), SDLoc(N), VT, N0.getOperand(0)); + // fold (aext (aext_extend_vector_inreg x)) -> (aext_extend_vector_inreg x) + // fold (aext (zext_extend_vector_inreg x)) -> (zext_extend_vector_inreg x) + // fold (aext (sext_extend_vector_inreg x)) -> (sext_extend_vector_inreg x) + if (N0.getOpcode() == ISD::ANY_EXTEND_VECTOR_INREG || + N0.getOpcode() == ISD::ZERO_EXTEND_VECTOR_INREG || + N0.getOpcode() == ISD::SIGN_EXTEND_VECTOR_INREG) + return DAG.getNode(N0.getOpcode(), SDLoc(N), VT, N0.getOperand(0)); + // fold (aext (truncate (load x))) -> (aext (smaller load x)) // fold (aext (truncate (srl (load x), c))) -> (aext (small load (x+c/n))) if (N0.getOpcode() == ISD::TRUNCATE) { @@ -13147,7 +13724,7 @@ SDValue DAGCombiner::visitANY_EXTEND(SDNode *N) { if (SDValue NewCtPop = widenCtPop(N, DAG)) return NewCtPop; - if (SDValue Res = tryToFoldExtendSelectLoad(N, TLI, DAG)) + if (SDValue Res = tryToFoldExtendSelectLoad(N, TLI, DAG, Level)) return Res; return SDValue(); @@ -13305,7 +13882,7 @@ SDValue DAGCombiner::reduceLoadWidth(SDNode *N) { const APInt &Mask = AndC->getAPIntValue(); unsigned ActiveBits = 0; if (Mask.isMask()) { - ActiveBits = Mask.countTrailingOnes(); + ActiveBits = Mask.countr_one(); } else if (Mask.isShiftedMask(ShAmt, ActiveBits)) { HasShiftedOffset = true; } else { @@ -13373,8 +13950,8 @@ SDValue DAGCombiner::reduceLoadWidth(SDNode *N) { isa<ConstantSDNode>(Mask->getOperand(1))) { const APInt& ShiftMask = Mask->getConstantOperandAPInt(1); if (ShiftMask.isMask()) { - EVT MaskedVT = EVT::getIntegerVT(*DAG.getContext(), - ShiftMask.countTrailingOnes()); + EVT MaskedVT = + EVT::getIntegerVT(*DAG.getContext(), ShiftMask.countr_one()); // If the mask is smaller, recompute the type. if ((ExtVT.getScalarSizeInBits() > MaskedVT.getScalarSizeInBits()) && TLI.isLoadExtLegal(ExtType, SRL.getValueType(), MaskedVT)) @@ -13520,9 +14097,7 @@ SDValue DAGCombiner::visitSIGN_EXTEND_INREG(SDNode *N) { // fold (sext_in_reg (*_extend_vector_inreg x)) -> (sext_vector_inreg x) // if x is small enough or if we know that x has more than 1 sign bit and the // sign_extend_inreg is extending from one of them. - if (N0.getOpcode() == ISD::ANY_EXTEND_VECTOR_INREG || - N0.getOpcode() == ISD::SIGN_EXTEND_VECTOR_INREG || - N0.getOpcode() == ISD::ZERO_EXTEND_VECTOR_INREG) { + if (ISD::isExtVecInRegOpcode(N0.getOpcode())) { SDValue N00 = N0.getOperand(0); unsigned N00Bits = N00.getScalarValueSizeInBits(); unsigned DstElts = N0.getValueType().getVectorMinNumElements(); @@ -13543,7 +14118,7 @@ SDValue DAGCombiner::visitSIGN_EXTEND_INREG(SDNode *N) { SDValue N00 = N0.getOperand(0); if (N00.getScalarValueSizeInBits() == ExtVTBits && (!LegalOperations || TLI.isOperationLegal(ISD::SIGN_EXTEND, VT))) - return DAG.getNode(ISD::SIGN_EXTEND, SDLoc(N), VT, N00, N1); + return DAG.getNode(ISD::SIGN_EXTEND, SDLoc(N), VT, N00); } // fold (sext_in_reg x) -> (zext_in_reg x) if the sign bit is known zero. @@ -13690,9 +14265,7 @@ foldExtendVectorInregToExtendOfSubvector(SDNode *N, const TargetLowering &TLI, Src.getValueType().getVectorElementType(), VT.getVectorElementCount()); - assert((InregOpcode == ISD::SIGN_EXTEND_VECTOR_INREG || - InregOpcode == ISD::ZERO_EXTEND_VECTOR_INREG || - InregOpcode == ISD::ANY_EXTEND_VECTOR_INREG) && + assert(ISD::isExtVecInRegOpcode(InregOpcode) && "Expected EXTEND_VECTOR_INREG dag node in input!"); // Profitability check: our operand must be an one-use CONCAT_VECTORS. @@ -13752,11 +14325,8 @@ SDValue DAGCombiner::visitTRUNCATE(SDNode *N) { return DAG.getNode(ISD::TRUNCATE, SDLoc(N), VT, N0.getOperand(0)); // fold (truncate c1) -> c1 - if (DAG.isConstantIntBuildVectorOrConstantInt(N0)) { - SDValue C = DAG.getNode(ISD::TRUNCATE, SDLoc(N), VT, N0); - if (C.getNode() != N) - return C; - } + if (SDValue C = DAG.FoldConstantArithmetic(ISD::TRUNCATE, SDLoc(N), VT, {N0})) + return C; // fold (truncate (ext x)) -> (ext x) or (truncate x) or x if (N0.getOpcode() == ISD::ZERO_EXTEND || @@ -13860,6 +14430,9 @@ SDValue DAGCombiner::visitTRUNCATE(SDNode *N) { if (SDValue V = foldSubToUSubSat(VT, N0.getNode())) return V; + if (SDValue ABD = foldABSToABD(N)) + return ABD; + // Attempt to pre-truncate BUILD_VECTOR sources. if (N0.getOpcode() == ISD::BUILD_VECTOR && !LegalOperations && TLI.isTruncateFree(SrcVT.getScalarType(), VT.getScalarType()) && @@ -14036,12 +14609,13 @@ SDValue DAGCombiner::visitTRUNCATE(SDNode *N) { } break; case ISD::ADDE: - case ISD::ADDCARRY: + case ISD::UADDO_CARRY: // (trunc adde(X, Y, Carry)) -> (adde trunc(X), trunc(Y), Carry) - // (trunc addcarry(X, Y, Carry)) -> (addcarry trunc(X), trunc(Y), Carry) + // (trunc uaddo_carry(X, Y, Carry)) -> + // (uaddo_carry trunc(X), trunc(Y), Carry) // When the adde's carry is not used. - // We only do for addcarry before legalize operation - if (((!LegalOperations && N0.getOpcode() == ISD::ADDCARRY) || + // We only do for uaddo_carry before legalize operation + if (((!LegalOperations && N0.getOpcode() == ISD::UADDO_CARRY) || TLI.isOperationLegal(N0.getOpcode(), VT)) && N0.hasOneUse() && !N0->hasAnyUseOfValue(1)) { SDLoc DL(N); @@ -14114,18 +14688,19 @@ static unsigned getPPCf128HiElementSelector(const SelectionDAG &DAG) { return DAG.getDataLayout().isBigEndian() ? 1 : 0; } -static SDValue foldBitcastedFPLogic(SDNode *N, SelectionDAG &DAG, - const TargetLowering &TLI) { +SDValue DAGCombiner::foldBitcastedFPLogic(SDNode *N, SelectionDAG &DAG, + const TargetLowering &TLI) { // If this is not a bitcast to an FP type or if the target doesn't have // IEEE754-compliant FP logic, we're done. EVT VT = N->getValueType(0); - if (!VT.isFloatingPoint() || !TLI.hasBitPreservingFPLogic(VT)) + SDValue N0 = N->getOperand(0); + EVT SourceVT = N0.getValueType(); + + if (!VT.isFloatingPoint()) return SDValue(); // TODO: Handle cases where the integer constant is a different scalar // bitwidth to the FP. - SDValue N0 = N->getOperand(0); - EVT SourceVT = N0.getValueType(); if (VT.getScalarSizeInBits() != SourceVT.getScalarSizeInBits()) return SDValue(); @@ -14148,6 +14723,19 @@ static SDValue foldBitcastedFPLogic(SDNode *N, SelectionDAG &DAG, return SDValue(); } + if (LegalOperations && !TLI.isOperationLegal(FPOpcode, VT)) + return SDValue(); + + // This needs to be the inverse of logic in foldSignChangeInBitcast. + // FIXME: I don't think looking for bitcast intrinsically makes sense, but + // removing this would require more changes. + auto IsBitCastOrFree = [&TLI, FPOpcode](SDValue Op, EVT VT) { + if (Op.getOpcode() == ISD::BITCAST && Op.getOperand(0).getValueType() == VT) + return true; + + return FPOpcode == ISD::FABS ? TLI.isFAbsFree(VT) : TLI.isFNegFree(VT); + }; + // Fold (bitcast int (and (bitcast fp X to int), 0x7fff...) to fp) -> fabs X // Fold (bitcast int (xor (bitcast fp X to int), 0x8000...) to fp) -> fneg X // Fold (bitcast int (or (bitcast fp X to int), 0x8000...) to fp) -> @@ -14155,9 +14743,9 @@ static SDValue foldBitcastedFPLogic(SDNode *N, SelectionDAG &DAG, SDValue LogicOp0 = N0.getOperand(0); ConstantSDNode *LogicOp1 = isConstOrConstSplat(N0.getOperand(1), true); if (LogicOp1 && LogicOp1->getAPIntValue() == SignMask && - LogicOp0.getOpcode() == ISD::BITCAST && - LogicOp0.getOperand(0).getValueType() == VT) { - SDValue FPOp = DAG.getNode(FPOpcode, SDLoc(N), VT, LogicOp0.getOperand(0)); + IsBitCastOrFree(LogicOp0, VT)) { + SDValue CastOp0 = DAG.getNode(ISD::BITCAST, SDLoc(N), VT, LogicOp0); + SDValue FPOp = DAG.getNode(FPOpcode, SDLoc(N), VT, CastOp0); NumFPLogicOpsConv++; if (N0.getOpcode() == ISD::OR) return DAG.getNode(ISD::FNEG, SDLoc(N), VT, FPOp); @@ -14209,6 +14797,22 @@ SDValue DAGCombiner::visitBITCAST(SDNode *N) { if (N0.getOpcode() == ISD::BITCAST) return DAG.getBitcast(VT, N0.getOperand(0)); + // fold (conv (logicop (conv x), (c))) -> (logicop x, (conv c)) + // iff the current bitwise logicop type isn't legal + if (ISD::isBitwiseLogicOp(N0.getOpcode()) && VT.isInteger() && + !TLI.isTypeLegal(N0.getOperand(0).getValueType())) { + auto IsFreeBitcast = [VT](SDValue V) { + return (V.getOpcode() == ISD::BITCAST && + V.getOperand(0).getValueType() == VT) || + (ISD::isBuildVectorOfConstantSDNodes(V.getNode()) && + V->hasOneUse()); + }; + if (IsFreeBitcast(N0.getOperand(0)) && IsFreeBitcast(N0.getOperand(1))) + return DAG.getNode(N0.getOpcode(), SDLoc(N), VT, + DAG.getBitcast(VT, N0.getOperand(0)), + DAG.getBitcast(VT, N0.getOperand(1))); + } + // fold (conv (load x)) -> (load (conv*)x) // If the resultant load doesn't need a higher alignment than the original! if (ISD::isNormalLoad(N0.getNode()) && N0.hasOneUse() && @@ -14437,7 +15041,9 @@ SDValue DAGCombiner::visitFREEZE(SDNode *N) { N0->getNumValues() != 1 || !N0->hasOneUse()) return SDValue(); - bool AllowMultipleMaybePoisonOperands = N0.getOpcode() == ISD::BUILD_VECTOR; + bool AllowMultipleMaybePoisonOperands = N0.getOpcode() == ISD::BUILD_VECTOR || + N0.getOpcode() == ISD::BUILD_PAIR || + N0.getOpcode() == ISD::CONCAT_VECTORS; SmallSetVector<SDValue, 8> MaybePoisonOperands; for (SDValue Op : N0->ops()) { @@ -14474,6 +15080,10 @@ SDValue DAGCombiner::visitFREEZE(SDNode *N) { } } + // This node has been merged with another. + if (N->getOpcode() == ISD::DELETED_NODE) + return SDValue(N, 0); + // The whole node may have been updated, so the value we were holding // may no longer be valid. Re-fetch the operand we're `freeze`ing. N0 = N->getOperand(0); @@ -14585,21 +15195,26 @@ static bool hasNoInfs(const TargetOptions &Options, SDValue N) { } /// Try to perform FMA combining on a given FADD node. +template <class MatchContextClass> SDValue DAGCombiner::visitFADDForFMACombine(SDNode *N) { SDValue N0 = N->getOperand(0); SDValue N1 = N->getOperand(1); EVT VT = N->getValueType(0); SDLoc SL(N); - + MatchContextClass matcher(DAG, TLI, N); const TargetOptions &Options = DAG.getTarget().Options; + bool UseVP = std::is_same_v<MatchContextClass, VPMatchContext>; + // Floating-point multiply-add with intermediate rounding. - bool HasFMAD = (LegalOperations && TLI.isFMADLegal(DAG, N)); + // FIXME: Make isFMADLegal have specific behavior when using VPMatchContext. + // FIXME: Add VP_FMAD opcode. + bool HasFMAD = !UseVP && (LegalOperations && TLI.isFMADLegal(DAG, N)); // Floating-point multiply-add without intermediate rounding. bool HasFMA = TLI.isFMAFasterThanFMulAndFAdd(DAG.getMachineFunction(), VT) && - (!LegalOperations || TLI.isOperationLegalOrCustom(ISD::FMA, VT)); + (!LegalOperations || matcher.isOperationLegalOrCustom(ISD::FMA, VT)); // No valid opcode, do not combine. if (!HasFMAD && !HasFMA) @@ -14613,6 +15228,13 @@ SDValue DAGCombiner::visitFADDForFMACombine(SDNode *N) { if (!AllowFusionGlobally && !N->getFlags().hasAllowContract()) return SDValue(); + // Folding fadd (fmul x, y), (fmul x, y) -> fma x, y, (fmul x, y) is never + // beneficial. It does not reduce latency. It increases register pressure. It + // replaces an fadd with an fma which is a more complex instruction, so is + // likely to have a larger encoding, use more functional units, etc. + if (N0 == N1) + return SDValue(); + if (TLI.generateFMAsInMachineCombiner(VT, OptLevel)) return SDValue(); @@ -14621,14 +15243,13 @@ SDValue DAGCombiner::visitFADDForFMACombine(SDNode *N) { bool Aggressive = TLI.enableAggressiveFMAFusion(VT); auto isFusedOp = [&](SDValue N) { - unsigned Opcode = N.getOpcode(); - return Opcode == ISD::FMA || Opcode == ISD::FMAD; + return matcher.match(N, ISD::FMA) || matcher.match(N, ISD::FMAD); }; // Is the node an FMUL and contractable either due to global flags or // SDNodeFlags. - auto isContractableFMUL = [AllowFusionGlobally](SDValue N) { - if (N.getOpcode() != ISD::FMUL) + auto isContractableFMUL = [AllowFusionGlobally, &matcher](SDValue N) { + if (!matcher.match(N, ISD::FMUL)) return false; return AllowFusionGlobally || N->getFlags().hasAllowContract(); }; @@ -14641,15 +15262,15 @@ SDValue DAGCombiner::visitFADDForFMACombine(SDNode *N) { // fold (fadd (fmul x, y), z) -> (fma x, y, z) if (isContractableFMUL(N0) && (Aggressive || N0->hasOneUse())) { - return DAG.getNode(PreferredFusedOpcode, SL, VT, N0.getOperand(0), - N0.getOperand(1), N1); + return matcher.getNode(PreferredFusedOpcode, SL, VT, N0.getOperand(0), + N0.getOperand(1), N1); } // fold (fadd x, (fmul y, z)) -> (fma y, z, x) // Note: Commutes FADD operands. if (isContractableFMUL(N1) && (Aggressive || N1->hasOneUse())) { - return DAG.getNode(PreferredFusedOpcode, SL, VT, N1.getOperand(0), - N1.getOperand(1), N0); + return matcher.getNode(PreferredFusedOpcode, SL, VT, N1.getOperand(0), + N1.getOperand(1), N0); } // fadd (fma A, B, (fmul C, D)), E --> fma A, B, (fma C, D, E) @@ -14673,10 +15294,10 @@ SDValue DAGCombiner::visitFADDForFMACombine(SDNode *N) { SDValue TmpFMA = FMA; while (E && isFusedOp(TmpFMA) && TmpFMA.hasOneUse()) { SDValue FMul = TmpFMA->getOperand(2); - if (FMul.getOpcode() == ISD::FMUL && FMul.hasOneUse()) { + if (matcher.match(FMul, ISD::FMUL) && FMul.hasOneUse()) { SDValue C = FMul.getOperand(0); SDValue D = FMul.getOperand(1); - SDValue CDE = DAG.getNode(PreferredFusedOpcode, SL, VT, C, D, E); + SDValue CDE = matcher.getNode(PreferredFusedOpcode, SL, VT, C, D, E); DAG.ReplaceAllUsesOfValueWith(FMul, CDE); // Replacing the inner FMul could cause the outer FMA to be simplified // away. @@ -14690,29 +15311,29 @@ SDValue DAGCombiner::visitFADDForFMACombine(SDNode *N) { // Look through FP_EXTEND nodes to do more combining. // fold (fadd (fpext (fmul x, y)), z) -> (fma (fpext x), (fpext y), z) - if (N0.getOpcode() == ISD::FP_EXTEND) { + if (matcher.match(N0, ISD::FP_EXTEND)) { SDValue N00 = N0.getOperand(0); if (isContractableFMUL(N00) && TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT, N00.getValueType())) { - return DAG.getNode(PreferredFusedOpcode, SL, VT, - DAG.getNode(ISD::FP_EXTEND, SL, VT, N00.getOperand(0)), - DAG.getNode(ISD::FP_EXTEND, SL, VT, N00.getOperand(1)), - N1); + return matcher.getNode( + PreferredFusedOpcode, SL, VT, + matcher.getNode(ISD::FP_EXTEND, SL, VT, N00.getOperand(0)), + matcher.getNode(ISD::FP_EXTEND, SL, VT, N00.getOperand(1)), N1); } } // fold (fadd x, (fpext (fmul y, z))) -> (fma (fpext y), (fpext z), x) // Note: Commutes FADD operands. - if (N1.getOpcode() == ISD::FP_EXTEND) { + if (matcher.match(N1, ISD::FP_EXTEND)) { SDValue N10 = N1.getOperand(0); if (isContractableFMUL(N10) && TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT, N10.getValueType())) { - return DAG.getNode(PreferredFusedOpcode, SL, VT, - DAG.getNode(ISD::FP_EXTEND, SL, VT, N10.getOperand(0)), - DAG.getNode(ISD::FP_EXTEND, SL, VT, N10.getOperand(1)), - N0); + return matcher.getNode( + PreferredFusedOpcode, SL, VT, + matcher.getNode(ISD::FP_EXTEND, SL, VT, N10.getOperand(0)), + matcher.getNode(ISD::FP_EXTEND, SL, VT, N10.getOperand(1)), N0); } } @@ -14722,15 +15343,15 @@ SDValue DAGCombiner::visitFADDForFMACombine(SDNode *N) { // -> (fma x, y, (fma (fpext u), (fpext v), z)) auto FoldFAddFMAFPExtFMul = [&](SDValue X, SDValue Y, SDValue U, SDValue V, SDValue Z) { - return DAG.getNode(PreferredFusedOpcode, SL, VT, X, Y, - DAG.getNode(PreferredFusedOpcode, SL, VT, - DAG.getNode(ISD::FP_EXTEND, SL, VT, U), - DAG.getNode(ISD::FP_EXTEND, SL, VT, V), - Z)); + return matcher.getNode( + PreferredFusedOpcode, SL, VT, X, Y, + matcher.getNode(PreferredFusedOpcode, SL, VT, + matcher.getNode(ISD::FP_EXTEND, SL, VT, U), + matcher.getNode(ISD::FP_EXTEND, SL, VT, V), Z)); }; if (isFusedOp(N0)) { SDValue N02 = N0.getOperand(2); - if (N02.getOpcode() == ISD::FP_EXTEND) { + if (matcher.match(N02, ISD::FP_EXTEND)) { SDValue N020 = N02.getOperand(0); if (isContractableFMUL(N020) && TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT, @@ -14749,12 +15370,13 @@ SDValue DAGCombiner::visitFADDForFMACombine(SDNode *N) { // interesting for all targets, especially GPUs. auto FoldFAddFPExtFMAFMul = [&](SDValue X, SDValue Y, SDValue U, SDValue V, SDValue Z) { - return DAG.getNode( - PreferredFusedOpcode, SL, VT, DAG.getNode(ISD::FP_EXTEND, SL, VT, X), - DAG.getNode(ISD::FP_EXTEND, SL, VT, Y), - DAG.getNode(PreferredFusedOpcode, SL, VT, - DAG.getNode(ISD::FP_EXTEND, SL, VT, U), - DAG.getNode(ISD::FP_EXTEND, SL, VT, V), Z)); + return matcher.getNode( + PreferredFusedOpcode, SL, VT, + matcher.getNode(ISD::FP_EXTEND, SL, VT, X), + matcher.getNode(ISD::FP_EXTEND, SL, VT, Y), + matcher.getNode(PreferredFusedOpcode, SL, VT, + matcher.getNode(ISD::FP_EXTEND, SL, VT, U), + matcher.getNode(ISD::FP_EXTEND, SL, VT, V), Z)); }; if (N0.getOpcode() == ISD::FP_EXTEND) { SDValue N00 = N0.getOperand(0); @@ -14810,20 +15432,26 @@ SDValue DAGCombiner::visitFADDForFMACombine(SDNode *N) { } /// Try to perform FMA combining on a given FSUB node. +template <class MatchContextClass> SDValue DAGCombiner::visitFSUBForFMACombine(SDNode *N) { SDValue N0 = N->getOperand(0); SDValue N1 = N->getOperand(1); EVT VT = N->getValueType(0); SDLoc SL(N); - + MatchContextClass matcher(DAG, TLI, N); const TargetOptions &Options = DAG.getTarget().Options; + + bool UseVP = std::is_same_v<MatchContextClass, VPMatchContext>; + // Floating-point multiply-add with intermediate rounding. - bool HasFMAD = (LegalOperations && TLI.isFMADLegal(DAG, N)); + // FIXME: Make isFMADLegal have specific behavior when using VPMatchContext. + // FIXME: Add VP_FMAD opcode. + bool HasFMAD = !UseVP && (LegalOperations && TLI.isFMADLegal(DAG, N)); // Floating-point multiply-add without intermediate rounding. bool HasFMA = TLI.isFMAFasterThanFMulAndFAdd(DAG.getMachineFunction(), VT) && - (!LegalOperations || TLI.isOperationLegalOrCustom(ISD::FMA, VT)); + (!LegalOperations || matcher.isOperationLegalOrCustom(ISD::FMA, VT)); // No valid opcode, do not combine. if (!HasFMAD && !HasFMA) @@ -14847,8 +15475,8 @@ SDValue DAGCombiner::visitFSUBForFMACombine(SDNode *N) { // Is the node an FMUL and contractable either due to global flags or // SDNodeFlags. - auto isContractableFMUL = [AllowFusionGlobally](SDValue N) { - if (N.getOpcode() != ISD::FMUL) + auto isContractableFMUL = [AllowFusionGlobally, &matcher](SDValue N) { + if (!matcher.match(N, ISD::FMUL)) return false; return AllowFusionGlobally || N->getFlags().hasAllowContract(); }; @@ -14856,8 +15484,9 @@ SDValue DAGCombiner::visitFSUBForFMACombine(SDNode *N) { // fold (fsub (fmul x, y), z) -> (fma x, y, (fneg z)) auto tryToFoldXYSubZ = [&](SDValue XY, SDValue Z) { if (isContractableFMUL(XY) && (Aggressive || XY->hasOneUse())) { - return DAG.getNode(PreferredFusedOpcode, SL, VT, XY.getOperand(0), - XY.getOperand(1), DAG.getNode(ISD::FNEG, SL, VT, Z)); + return matcher.getNode(PreferredFusedOpcode, SL, VT, XY.getOperand(0), + XY.getOperand(1), + matcher.getNode(ISD::FNEG, SL, VT, Z)); } return SDValue(); }; @@ -14866,9 +15495,10 @@ SDValue DAGCombiner::visitFSUBForFMACombine(SDNode *N) { // Note: Commutes FSUB operands. auto tryToFoldXSubYZ = [&](SDValue X, SDValue YZ) { if (isContractableFMUL(YZ) && (Aggressive || YZ->hasOneUse())) { - return DAG.getNode(PreferredFusedOpcode, SL, VT, - DAG.getNode(ISD::FNEG, SL, VT, YZ.getOperand(0)), - YZ.getOperand(1), X); + return matcher.getNode( + PreferredFusedOpcode, SL, VT, + matcher.getNode(ISD::FNEG, SL, VT, YZ.getOperand(0)), + YZ.getOperand(1), X); } return SDValue(); }; @@ -14893,44 +15523,46 @@ SDValue DAGCombiner::visitFSUBForFMACombine(SDNode *N) { } // fold (fsub (fneg (fmul, x, y)), z) -> (fma (fneg x), y, (fneg z)) - if (N0.getOpcode() == ISD::FNEG && isContractableFMUL(N0.getOperand(0)) && + if (matcher.match(N0, ISD::FNEG) && isContractableFMUL(N0.getOperand(0)) && (Aggressive || (N0->hasOneUse() && N0.getOperand(0).hasOneUse()))) { SDValue N00 = N0.getOperand(0).getOperand(0); SDValue N01 = N0.getOperand(0).getOperand(1); - return DAG.getNode(PreferredFusedOpcode, SL, VT, - DAG.getNode(ISD::FNEG, SL, VT, N00), N01, - DAG.getNode(ISD::FNEG, SL, VT, N1)); + return matcher.getNode(PreferredFusedOpcode, SL, VT, + matcher.getNode(ISD::FNEG, SL, VT, N00), N01, + matcher.getNode(ISD::FNEG, SL, VT, N1)); } // Look through FP_EXTEND nodes to do more combining. // fold (fsub (fpext (fmul x, y)), z) // -> (fma (fpext x), (fpext y), (fneg z)) - if (N0.getOpcode() == ISD::FP_EXTEND) { + if (matcher.match(N0, ISD::FP_EXTEND)) { SDValue N00 = N0.getOperand(0); if (isContractableFMUL(N00) && TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT, N00.getValueType())) { - return DAG.getNode(PreferredFusedOpcode, SL, VT, - DAG.getNode(ISD::FP_EXTEND, SL, VT, N00.getOperand(0)), - DAG.getNode(ISD::FP_EXTEND, SL, VT, N00.getOperand(1)), - DAG.getNode(ISD::FNEG, SL, VT, N1)); + return matcher.getNode( + PreferredFusedOpcode, SL, VT, + matcher.getNode(ISD::FP_EXTEND, SL, VT, N00.getOperand(0)), + matcher.getNode(ISD::FP_EXTEND, SL, VT, N00.getOperand(1)), + matcher.getNode(ISD::FNEG, SL, VT, N1)); } } // fold (fsub x, (fpext (fmul y, z))) // -> (fma (fneg (fpext y)), (fpext z), x) // Note: Commutes FSUB operands. - if (N1.getOpcode() == ISD::FP_EXTEND) { + if (matcher.match(N1, ISD::FP_EXTEND)) { SDValue N10 = N1.getOperand(0); if (isContractableFMUL(N10) && TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT, N10.getValueType())) { - return DAG.getNode( + return matcher.getNode( PreferredFusedOpcode, SL, VT, - DAG.getNode(ISD::FNEG, SL, VT, - DAG.getNode(ISD::FP_EXTEND, SL, VT, N10.getOperand(0))), - DAG.getNode(ISD::FP_EXTEND, SL, VT, N10.getOperand(1)), N0); + matcher.getNode( + ISD::FNEG, SL, VT, + matcher.getNode(ISD::FP_EXTEND, SL, VT, N10.getOperand(0))), + matcher.getNode(ISD::FP_EXTEND, SL, VT, N10.getOperand(1)), N0); } } @@ -14940,19 +15572,20 @@ SDValue DAGCombiner::visitFSUBForFMACombine(SDNode *N) { // input expression into (fneg (fadd (fpext (fmul, x, y)), z). However, the // orthogonal flags -fp-contract=fast and -enable-unsafe-fp-math prevent // from implementing the canonicalization in visitFSUB. - if (N0.getOpcode() == ISD::FP_EXTEND) { + if (matcher.match(N0, ISD::FP_EXTEND)) { SDValue N00 = N0.getOperand(0); - if (N00.getOpcode() == ISD::FNEG) { + if (matcher.match(N00, ISD::FNEG)) { SDValue N000 = N00.getOperand(0); if (isContractableFMUL(N000) && TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT, N00.getValueType())) { - return DAG.getNode( + return matcher.getNode( ISD::FNEG, SL, VT, - DAG.getNode(PreferredFusedOpcode, SL, VT, - DAG.getNode(ISD::FP_EXTEND, SL, VT, N000.getOperand(0)), - DAG.getNode(ISD::FP_EXTEND, SL, VT, N000.getOperand(1)), - N1)); + matcher.getNode( + PreferredFusedOpcode, SL, VT, + matcher.getNode(ISD::FP_EXTEND, SL, VT, N000.getOperand(0)), + matcher.getNode(ISD::FP_EXTEND, SL, VT, N000.getOperand(1)), + N1)); } } } @@ -14963,24 +15596,25 @@ SDValue DAGCombiner::visitFSUBForFMACombine(SDNode *N) { // input expression into (fneg (fadd (fpext (fmul, x, y)), z). However, the // orthogonal flags -fp-contract=fast and -enable-unsafe-fp-math prevent // from implementing the canonicalization in visitFSUB. - if (N0.getOpcode() == ISD::FNEG) { + if (matcher.match(N0, ISD::FNEG)) { SDValue N00 = N0.getOperand(0); - if (N00.getOpcode() == ISD::FP_EXTEND) { + if (matcher.match(N00, ISD::FP_EXTEND)) { SDValue N000 = N00.getOperand(0); if (isContractableFMUL(N000) && TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT, N000.getValueType())) { - return DAG.getNode( + return matcher.getNode( ISD::FNEG, SL, VT, - DAG.getNode(PreferredFusedOpcode, SL, VT, - DAG.getNode(ISD::FP_EXTEND, SL, VT, N000.getOperand(0)), - DAG.getNode(ISD::FP_EXTEND, SL, VT, N000.getOperand(1)), - N1)); + matcher.getNode( + PreferredFusedOpcode, SL, VT, + matcher.getNode(ISD::FP_EXTEND, SL, VT, N000.getOperand(0)), + matcher.getNode(ISD::FP_EXTEND, SL, VT, N000.getOperand(1)), + N1)); } } } - auto isReassociable = [Options](SDNode *N) { + auto isReassociable = [&Options](SDNode *N) { return Options.UnsafeFPMath || N->getFlags().hasAllowReassociation(); }; @@ -14990,8 +15624,7 @@ SDValue DAGCombiner::visitFSUBForFMACombine(SDNode *N) { }; auto isFusedOp = [&](SDValue N) { - unsigned Opcode = N.getOpcode(); - return Opcode == ISD::FMA || Opcode == ISD::FMAD; + return matcher.match(N, ISD::FMA) || matcher.match(N, ISD::FMAD); }; // More folding opportunities when target permits. @@ -15002,12 +15635,12 @@ SDValue DAGCombiner::visitFSUBForFMACombine(SDNode *N) { if (CanFuse && isFusedOp(N0) && isContractableAndReassociableFMUL(N0.getOperand(2)) && N0->hasOneUse() && N0.getOperand(2)->hasOneUse()) { - return DAG.getNode(PreferredFusedOpcode, SL, VT, N0.getOperand(0), - N0.getOperand(1), - DAG.getNode(PreferredFusedOpcode, SL, VT, - N0.getOperand(2).getOperand(0), - N0.getOperand(2).getOperand(1), - DAG.getNode(ISD::FNEG, SL, VT, N1))); + return matcher.getNode( + PreferredFusedOpcode, SL, VT, N0.getOperand(0), N0.getOperand(1), + matcher.getNode(PreferredFusedOpcode, SL, VT, + N0.getOperand(2).getOperand(0), + N0.getOperand(2).getOperand(1), + matcher.getNode(ISD::FNEG, SL, VT, N1))); } // fold (fsub x, (fma y, z, (fmul u, v))) @@ -15017,29 +15650,30 @@ SDValue DAGCombiner::visitFSUBForFMACombine(SDNode *N) { N1->hasOneUse() && NoSignedZero) { SDValue N20 = N1.getOperand(2).getOperand(0); SDValue N21 = N1.getOperand(2).getOperand(1); - return DAG.getNode( + return matcher.getNode( PreferredFusedOpcode, SL, VT, - DAG.getNode(ISD::FNEG, SL, VT, N1.getOperand(0)), N1.getOperand(1), - DAG.getNode(PreferredFusedOpcode, SL, VT, - DAG.getNode(ISD::FNEG, SL, VT, N20), N21, N0)); + matcher.getNode(ISD::FNEG, SL, VT, N1.getOperand(0)), + N1.getOperand(1), + matcher.getNode(PreferredFusedOpcode, SL, VT, + matcher.getNode(ISD::FNEG, SL, VT, N20), N21, N0)); } // fold (fsub (fma x, y, (fpext (fmul u, v))), z) // -> (fma x, y (fma (fpext u), (fpext v), (fneg z))) if (isFusedOp(N0) && N0->hasOneUse()) { SDValue N02 = N0.getOperand(2); - if (N02.getOpcode() == ISD::FP_EXTEND) { + if (matcher.match(N02, ISD::FP_EXTEND)) { SDValue N020 = N02.getOperand(0); if (isContractableAndReassociableFMUL(N020) && TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT, N020.getValueType())) { - return DAG.getNode( + return matcher.getNode( PreferredFusedOpcode, SL, VT, N0.getOperand(0), N0.getOperand(1), - DAG.getNode( + matcher.getNode( PreferredFusedOpcode, SL, VT, - DAG.getNode(ISD::FP_EXTEND, SL, VT, N020.getOperand(0)), - DAG.getNode(ISD::FP_EXTEND, SL, VT, N020.getOperand(1)), - DAG.getNode(ISD::FNEG, SL, VT, N1))); + matcher.getNode(ISD::FP_EXTEND, SL, VT, N020.getOperand(0)), + matcher.getNode(ISD::FP_EXTEND, SL, VT, N020.getOperand(1)), + matcher.getNode(ISD::FNEG, SL, VT, N1))); } } } @@ -15050,29 +15684,29 @@ SDValue DAGCombiner::visitFSUBForFMACombine(SDNode *N) { // FIXME: This turns two single-precision and one double-precision // operation into two double-precision operations, which might not be // interesting for all targets, especially GPUs. - if (N0.getOpcode() == ISD::FP_EXTEND) { + if (matcher.match(N0, ISD::FP_EXTEND)) { SDValue N00 = N0.getOperand(0); if (isFusedOp(N00)) { SDValue N002 = N00.getOperand(2); if (isContractableAndReassociableFMUL(N002) && TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT, N00.getValueType())) { - return DAG.getNode( + return matcher.getNode( PreferredFusedOpcode, SL, VT, - DAG.getNode(ISD::FP_EXTEND, SL, VT, N00.getOperand(0)), - DAG.getNode(ISD::FP_EXTEND, SL, VT, N00.getOperand(1)), - DAG.getNode( + matcher.getNode(ISD::FP_EXTEND, SL, VT, N00.getOperand(0)), + matcher.getNode(ISD::FP_EXTEND, SL, VT, N00.getOperand(1)), + matcher.getNode( PreferredFusedOpcode, SL, VT, - DAG.getNode(ISD::FP_EXTEND, SL, VT, N002.getOperand(0)), - DAG.getNode(ISD::FP_EXTEND, SL, VT, N002.getOperand(1)), - DAG.getNode(ISD::FNEG, SL, VT, N1))); + matcher.getNode(ISD::FP_EXTEND, SL, VT, N002.getOperand(0)), + matcher.getNode(ISD::FP_EXTEND, SL, VT, N002.getOperand(1)), + matcher.getNode(ISD::FNEG, SL, VT, N1))); } } } // fold (fsub x, (fma y, z, (fpext (fmul u, v)))) // -> (fma (fneg y), z, (fma (fneg (fpext u)), (fpext v), x)) - if (isFusedOp(N1) && N1.getOperand(2).getOpcode() == ISD::FP_EXTEND && + if (isFusedOp(N1) && matcher.match(N1.getOperand(2), ISD::FP_EXTEND) && N1->hasOneUse()) { SDValue N120 = N1.getOperand(2).getOperand(0); if (isContractableAndReassociableFMUL(N120) && @@ -15080,13 +15714,15 @@ SDValue DAGCombiner::visitFSUBForFMACombine(SDNode *N) { N120.getValueType())) { SDValue N1200 = N120.getOperand(0); SDValue N1201 = N120.getOperand(1); - return DAG.getNode( + return matcher.getNode( PreferredFusedOpcode, SL, VT, - DAG.getNode(ISD::FNEG, SL, VT, N1.getOperand(0)), N1.getOperand(1), - DAG.getNode(PreferredFusedOpcode, SL, VT, - DAG.getNode(ISD::FNEG, SL, VT, - DAG.getNode(ISD::FP_EXTEND, SL, VT, N1200)), - DAG.getNode(ISD::FP_EXTEND, SL, VT, N1201), N0)); + matcher.getNode(ISD::FNEG, SL, VT, N1.getOperand(0)), + N1.getOperand(1), + matcher.getNode( + PreferredFusedOpcode, SL, VT, + matcher.getNode(ISD::FNEG, SL, VT, + matcher.getNode(ISD::FP_EXTEND, SL, VT, N1200)), + matcher.getNode(ISD::FP_EXTEND, SL, VT, N1201), N0)); } } @@ -15096,7 +15732,7 @@ SDValue DAGCombiner::visitFSUBForFMACombine(SDNode *N) { // FIXME: This turns two single-precision and one double-precision // operation into two double-precision operations, which might not be // interesting for all targets, especially GPUs. - if (N1.getOpcode() == ISD::FP_EXTEND && isFusedOp(N1.getOperand(0))) { + if (matcher.match(N1, ISD::FP_EXTEND) && isFusedOp(N1.getOperand(0))) { SDValue CvtSrc = N1.getOperand(0); SDValue N100 = CvtSrc.getOperand(0); SDValue N101 = CvtSrc.getOperand(1); @@ -15106,15 +15742,16 @@ SDValue DAGCombiner::visitFSUBForFMACombine(SDNode *N) { CvtSrc.getValueType())) { SDValue N1020 = N102.getOperand(0); SDValue N1021 = N102.getOperand(1); - return DAG.getNode( + return matcher.getNode( PreferredFusedOpcode, SL, VT, - DAG.getNode(ISD::FNEG, SL, VT, - DAG.getNode(ISD::FP_EXTEND, SL, VT, N100)), - DAG.getNode(ISD::FP_EXTEND, SL, VT, N101), - DAG.getNode(PreferredFusedOpcode, SL, VT, - DAG.getNode(ISD::FNEG, SL, VT, - DAG.getNode(ISD::FP_EXTEND, SL, VT, N1020)), - DAG.getNode(ISD::FP_EXTEND, SL, VT, N1021), N0)); + matcher.getNode(ISD::FNEG, SL, VT, + matcher.getNode(ISD::FP_EXTEND, SL, VT, N100)), + matcher.getNode(ISD::FP_EXTEND, SL, VT, N101), + matcher.getNode( + PreferredFusedOpcode, SL, VT, + matcher.getNode(ISD::FNEG, SL, VT, + matcher.getNode(ISD::FP_EXTEND, SL, VT, N1020)), + matcher.getNode(ISD::FP_EXTEND, SL, VT, N1021), N0)); } } } @@ -15217,6 +15854,17 @@ SDValue DAGCombiner::visitFMULForFMADistributiveCombine(SDNode *N) { return SDValue(); } +SDValue DAGCombiner::visitVP_FADD(SDNode *N) { + SelectionDAG::FlagInserter FlagsInserter(DAG, N); + + // FADD -> FMA combines: + if (SDValue Fused = visitFADDForFMACombine<VPMatchContext>(N)) { + AddToWorklist(Fused.getNode()); + return Fused; + } + return SDValue(); +} + SDValue DAGCombiner::visitFADD(SDNode *N) { SDValue N0 = N->getOperand(0); SDValue N1 = N->getOperand(1); @@ -15394,10 +16042,15 @@ SDValue DAGCombiner::visitFADD(SDNode *N) { DAG.getConstantFP(4.0, DL, VT)); } } + + // Fold fadd(vecreduce(x), vecreduce(y)) -> vecreduce(fadd(x, y)) + if (SDValue SD = reassociateReduction(ISD::VECREDUCE_FADD, ISD::FADD, DL, + VT, N0, N1, Flags)) + return SD; } // enable-unsafe-fp-math // FADD -> FMA combines: - if (SDValue Fused = visitFADDForFMACombine(N)) { + if (SDValue Fused = visitFADDForFMACombine<EmptyMatchContext>(N)) { AddToWorklist(Fused.getNode()); return Fused; } @@ -15507,7 +16160,7 @@ SDValue DAGCombiner::visitFSUB(SDNode *N) { return DAG.getNode(ISD::FADD, DL, VT, N0, NegN1); // FSUB -> FMA combines: - if (SDValue Fused = visitFSUBForFMACombine(N)) { + if (SDValue Fused = visitFSUBForFMACombine<EmptyMatchContext>(N)) { AddToWorklist(Fused.getNode()); return Fused; } @@ -15568,6 +16221,11 @@ SDValue DAGCombiner::visitFMUL(SDNode *N) { SDValue MulConsts = DAG.getNode(ISD::FMUL, DL, VT, Two, N1); return DAG.getNode(ISD::FMUL, DL, VT, N0.getOperand(0), MulConsts); } + + // Fold fmul(vecreduce(x), vecreduce(y)) -> vecreduce(fmul(x, y)) + if (SDValue SD = reassociateReduction(ISD::VECREDUCE_FMUL, ISD::FMUL, DL, + VT, N0, N1, Flags)) + return SD; } // fold (fmul X, 2.0) -> (fadd X, X) @@ -15653,7 +16311,7 @@ SDValue DAGCombiner::visitFMUL(SDNode *N) { return SDValue(); } -SDValue DAGCombiner::visitFMA(SDNode *N) { +template <class MatchContextClass> SDValue DAGCombiner::visitFMA(SDNode *N) { SDValue N0 = N->getOperand(0); SDValue N1 = N->getOperand(1); SDValue N2 = N->getOperand(2); @@ -15664,6 +16322,7 @@ SDValue DAGCombiner::visitFMA(SDNode *N) { const TargetOptions &Options = DAG.getTarget().Options; // FMA nodes have flags that propagate to the created nodes. SelectionDAG::FlagInserter FlagsInserter(DAG, N); + MatchContextClass matcher(DAG, TLI, N); bool CanReassociate = Options.UnsafeFPMath || N->getFlags().hasAllowReassociation(); @@ -15672,7 +16331,7 @@ SDValue DAGCombiner::visitFMA(SDNode *N) { if (isa<ConstantFPSDNode>(N0) && isa<ConstantFPSDNode>(N1) && isa<ConstantFPSDNode>(N2)) { - return DAG.getNode(ISD::FMA, DL, VT, N0, N1, N2); + return matcher.getNode(ISD::FMA, DL, VT, N0, N1, N2); } // (-N0 * -N1) + N2 --> (N0 * N1) + N2 @@ -15688,7 +16347,7 @@ SDValue DAGCombiner::visitFMA(SDNode *N) { TLI.getNegatedExpression(N1, DAG, LegalOperations, ForCodeSize, CostN1); if (NegN1 && (CostN0 == TargetLowering::NegatibleCost::Cheaper || CostN1 == TargetLowering::NegatibleCost::Cheaper)) - return DAG.getNode(ISD::FMA, DL, VT, NegN0, NegN1, N2); + return matcher.getNode(ISD::FMA, DL, VT, NegN0, NegN1, N2); } // FIXME: use fast math flags instead of Options.UnsafeFPMath @@ -15699,70 +16358,74 @@ SDValue DAGCombiner::visitFMA(SDNode *N) { return N2; } + // FIXME: Support splat of constant. if (N0CFP && N0CFP->isExactlyValue(1.0)) - return DAG.getNode(ISD::FADD, SDLoc(N), VT, N1, N2); + return matcher.getNode(ISD::FADD, SDLoc(N), VT, N1, N2); if (N1CFP && N1CFP->isExactlyValue(1.0)) - return DAG.getNode(ISD::FADD, SDLoc(N), VT, N0, N2); + return matcher.getNode(ISD::FADD, SDLoc(N), VT, N0, N2); // Canonicalize (fma c, x, y) -> (fma x, c, y) if (DAG.isConstantFPBuildVectorOrConstantFP(N0) && !DAG.isConstantFPBuildVectorOrConstantFP(N1)) - return DAG.getNode(ISD::FMA, SDLoc(N), VT, N1, N0, N2); + return matcher.getNode(ISD::FMA, SDLoc(N), VT, N1, N0, N2); if (CanReassociate) { // (fma x, c1, (fmul x, c2)) -> (fmul x, c1+c2) - if (N2.getOpcode() == ISD::FMUL && N0 == N2.getOperand(0) && + if (matcher.match(N2, ISD::FMUL) && N0 == N2.getOperand(0) && DAG.isConstantFPBuildVectorOrConstantFP(N1) && DAG.isConstantFPBuildVectorOrConstantFP(N2.getOperand(1))) { - return DAG.getNode(ISD::FMUL, DL, VT, N0, - DAG.getNode(ISD::FADD, DL, VT, N1, N2.getOperand(1))); + return matcher.getNode( + ISD::FMUL, DL, VT, N0, + matcher.getNode(ISD::FADD, DL, VT, N1, N2.getOperand(1))); } // (fma (fmul x, c1), c2, y) -> (fma x, c1*c2, y) - if (N0.getOpcode() == ISD::FMUL && + if (matcher.match(N0, ISD::FMUL) && DAG.isConstantFPBuildVectorOrConstantFP(N1) && DAG.isConstantFPBuildVectorOrConstantFP(N0.getOperand(1))) { - return DAG.getNode(ISD::FMA, DL, VT, N0.getOperand(0), - DAG.getNode(ISD::FMUL, DL, VT, N1, N0.getOperand(1)), - N2); + return matcher.getNode( + ISD::FMA, DL, VT, N0.getOperand(0), + matcher.getNode(ISD::FMUL, DL, VT, N1, N0.getOperand(1)), N2); } } // (fma x, -1, y) -> (fadd (fneg x), y) + // FIXME: Support splat of constant. if (N1CFP) { if (N1CFP->isExactlyValue(1.0)) - return DAG.getNode(ISD::FADD, DL, VT, N0, N2); + return matcher.getNode(ISD::FADD, DL, VT, N0, N2); if (N1CFP->isExactlyValue(-1.0) && (!LegalOperations || TLI.isOperationLegal(ISD::FNEG, VT))) { - SDValue RHSNeg = DAG.getNode(ISD::FNEG, DL, VT, N0); + SDValue RHSNeg = matcher.getNode(ISD::FNEG, DL, VT, N0); AddToWorklist(RHSNeg.getNode()); - return DAG.getNode(ISD::FADD, DL, VT, N2, RHSNeg); + return matcher.getNode(ISD::FADD, DL, VT, N2, RHSNeg); } // fma (fneg x), K, y -> fma x -K, y - if (N0.getOpcode() == ISD::FNEG && + if (matcher.match(N0, ISD::FNEG) && (TLI.isOperationLegal(ISD::ConstantFP, VT) || - (N1.hasOneUse() && !TLI.isFPImmLegal(N1CFP->getValueAPF(), VT, - ForCodeSize)))) { - return DAG.getNode(ISD::FMA, DL, VT, N0.getOperand(0), - DAG.getNode(ISD::FNEG, DL, VT, N1), N2); + (N1.hasOneUse() && + !TLI.isFPImmLegal(N1CFP->getValueAPF(), VT, ForCodeSize)))) { + return matcher.getNode(ISD::FMA, DL, VT, N0.getOperand(0), + matcher.getNode(ISD::FNEG, DL, VT, N1), N2); } } + // FIXME: Support splat of constant. if (CanReassociate) { // (fma x, c, x) -> (fmul x, (c+1)) if (N1CFP && N0 == N2) { - return DAG.getNode( - ISD::FMUL, DL, VT, N0, - DAG.getNode(ISD::FADD, DL, VT, N1, DAG.getConstantFP(1.0, DL, VT))); + return matcher.getNode(ISD::FMUL, DL, VT, N0, + matcher.getNode(ISD::FADD, DL, VT, N1, + DAG.getConstantFP(1.0, DL, VT))); } // (fma x, c, (fneg x)) -> (fmul x, (c-1)) - if (N1CFP && N2.getOpcode() == ISD::FNEG && N2.getOperand(0) == N0) { - return DAG.getNode( - ISD::FMUL, DL, VT, N0, - DAG.getNode(ISD::FADD, DL, VT, N1, DAG.getConstantFP(-1.0, DL, VT))); + if (N1CFP && matcher.match(N2, ISD::FNEG) && N2.getOperand(0) == N0) { + return matcher.getNode(ISD::FMUL, DL, VT, N0, + matcher.getNode(ISD::FADD, DL, VT, N1, + DAG.getConstantFP(-1.0, DL, VT))); } } @@ -15771,7 +16434,7 @@ SDValue DAGCombiner::visitFMA(SDNode *N) { if (!TLI.isFNegFree(VT)) if (SDValue Neg = TLI.getCheaperNegatedExpression( SDValue(N, 0), DAG, LegalOperations, ForCodeSize)) - return DAG.getNode(ISD::FNEG, DL, VT, Neg); + return matcher.getNode(ISD::FNEG, DL, VT, Neg); return SDValue(); } @@ -16043,27 +16706,30 @@ SDValue DAGCombiner::visitFSQRT(SDNode *N) { /// copysign(x, fp_extend(y)) -> copysign(x, y) /// copysign(x, fp_round(y)) -> copysign(x, y) -static inline bool CanCombineFCOPYSIGN_EXTEND_ROUND(SDNode *N) { - SDValue N1 = N->getOperand(1); - if ((N1.getOpcode() == ISD::FP_EXTEND || - N1.getOpcode() == ISD::FP_ROUND)) { - EVT N1VT = N1->getValueType(0); - EVT N1Op0VT = N1->getOperand(0).getValueType(); +/// Operands to the functions are the type of X and Y respectively. +static inline bool CanCombineFCOPYSIGN_EXTEND_ROUND(EVT XTy, EVT YTy) { + // Always fold no-op FP casts. + if (XTy == YTy) + return true; - // Always fold no-op FP casts. - if (N1VT == N1Op0VT) - return true; + // Do not optimize out type conversion of f128 type yet. + // For some targets like x86_64, configuration is changed to keep one f128 + // value in one SSE register, but instruction selection cannot handle + // FCOPYSIGN on SSE registers yet. + if (YTy == MVT::f128) + return false; - // Do not optimize out type conversion of f128 type yet. - // For some targets like x86_64, configuration is changed to keep one f128 - // value in one SSE register, but instruction selection cannot handle - // FCOPYSIGN on SSE registers yet. - if (N1Op0VT == MVT::f128) - return false; + return !YTy.isVector() || EnableVectorFCopySignExtendRound; +} - return !N1Op0VT.isVector() || EnableVectorFCopySignExtendRound; - } - return false; +static inline bool CanCombineFCOPYSIGN_EXTEND_ROUND(SDNode *N) { + SDValue N1 = N->getOperand(1); + if (N1.getOpcode() != ISD::FP_EXTEND && + N1.getOpcode() != ISD::FP_ROUND) + return false; + EVT N1VT = N1->getValueType(0); + EVT N1Op0VT = N1->getOperand(0).getValueType(); + return CanCombineFCOPYSIGN_EXTEND_ROUND(N1VT, N1Op0VT); } SDValue DAGCombiner::visitFCOPYSIGN(SDNode *N) { @@ -16399,6 +17065,10 @@ SDValue DAGCombiner::visitFP_ROUND(SDNode *N) { const bool NIsTrunc = N->getConstantOperandVal(1) == 1; const bool N0IsTrunc = N0.getConstantOperandVal(1) == 1; + // Avoid folding legal fp_rounds into non-legal ones. + if (!hasOperation(ISD::FP_ROUND, VT)) + return SDValue(); + // Skip this folding if it results in an fp_round from f80 to f16. // // f80 to f16 always generates an expensive (and as yet, unimplemented) @@ -16423,7 +17093,13 @@ SDValue DAGCombiner::visitFP_ROUND(SDNode *N) { } // fold (fp_round (copysign X, Y)) -> (copysign (fp_round X), Y) - if (N0.getOpcode() == ISD::FCOPYSIGN && N0->hasOneUse()) { + // Note: From a legality perspective, this is a two step transform. First, + // we duplicate the fp_round to the arguments of the copysign, then we + // eliminate the fp_round on Y. The second step requires an additional + // predicate to match the implementation above. + if (N0.getOpcode() == ISD::FCOPYSIGN && N0->hasOneUse() && + CanCombineFCOPYSIGN_EXTEND_ROUND(VT, + N0.getValueType())) { SDValue Tmp = DAG.getNode(ISD::FP_ROUND, SDLoc(N0), VT, N0.getOperand(0), N1); AddToWorklist(Tmp.getNode()); @@ -16529,6 +17205,15 @@ SDValue DAGCombiner::visitFTRUNC(SDNode *N) { return SDValue(); } +SDValue DAGCombiner::visitFFREXP(SDNode *N) { + SDValue N0 = N->getOperand(0); + + // fold (ffrexp c1) -> ffrexp(c1) + if (DAG.isConstantFPBuildVectorOrConstantFP(N0)) + return DAG.getNode(ISD::FFREXP, SDLoc(N), N->getVTList(), N0); + return SDValue(); +} + SDValue DAGCombiner::visitFFLOOR(SDNode *N) { SDValue N0 = N->getOperand(0); EVT VT = N->getValueType(0); @@ -16618,6 +17303,13 @@ SDValue DAGCombiner::visitFMinMax(SDNode *N) { } } + if (SDValue SD = reassociateReduction( + PropagatesNaN + ? (IsMin ? ISD::VECREDUCE_FMINIMUM : ISD::VECREDUCE_FMAXIMUM) + : (IsMin ? ISD::VECREDUCE_FMIN : ISD::VECREDUCE_FMAX), + Opc, SDLoc(N), VT, N0, N1, Flags)) + return SD; + return SDValue(); } @@ -16656,6 +17348,55 @@ SDValue DAGCombiner::visitBRCOND(SDNode *N) { N1->getOperand(0), N2); } + // Variant of the previous fold where there is a SETCC in between: + // BRCOND(SETCC(FREEZE(X), CONST, Cond)) + // => + // BRCOND(FREEZE(SETCC(X, CONST, Cond))) + // => + // BRCOND(SETCC(X, CONST, Cond)) + // This is correct if FREEZE(X) has one use and SETCC(FREEZE(X), CONST, Cond) + // isn't equivalent to true or false. + // For example, SETCC(FREEZE(X), -128, SETULT) cannot be folded to + // FREEZE(SETCC(X, -128, SETULT)) because X can be poison. + if (N1->getOpcode() == ISD::SETCC && N1.hasOneUse()) { + SDValue S0 = N1->getOperand(0), S1 = N1->getOperand(1); + ISD::CondCode Cond = cast<CondCodeSDNode>(N1->getOperand(2))->get(); + ConstantSDNode *S0C = dyn_cast<ConstantSDNode>(S0); + ConstantSDNode *S1C = dyn_cast<ConstantSDNode>(S1); + bool Updated = false; + + // Is 'X Cond C' always true or false? + auto IsAlwaysTrueOrFalse = [](ISD::CondCode Cond, ConstantSDNode *C) { + bool False = (Cond == ISD::SETULT && C->isZero()) || + (Cond == ISD::SETLT && C->isMinSignedValue()) || + (Cond == ISD::SETUGT && C->isAllOnes()) || + (Cond == ISD::SETGT && C->isMaxSignedValue()); + bool True = (Cond == ISD::SETULE && C->isAllOnes()) || + (Cond == ISD::SETLE && C->isMaxSignedValue()) || + (Cond == ISD::SETUGE && C->isZero()) || + (Cond == ISD::SETGE && C->isMinSignedValue()); + return True || False; + }; + + if (S0->getOpcode() == ISD::FREEZE && S0.hasOneUse() && S1C) { + if (!IsAlwaysTrueOrFalse(Cond, S1C)) { + S0 = S0->getOperand(0); + Updated = true; + } + } + if (S1->getOpcode() == ISD::FREEZE && S1.hasOneUse() && S0C) { + if (!IsAlwaysTrueOrFalse(ISD::getSetCCSwappedOperands(Cond), S0C)) { + S1 = S1->getOperand(0); + Updated = true; + } + } + + if (Updated) + return DAG.getNode( + ISD::BRCOND, SDLoc(N), MVT::Other, Chain, + DAG.getSetCC(SDLoc(N1), N1->getValueType(0), S0, S1, Cond), N2); + } + // If N is a constant we could fold this into a fallthrough or unconditional // branch. However that doesn't happen very often in normal code, because // Instcombine/SimplifyCFG should have handled the available opportunities. @@ -17288,11 +18029,53 @@ bool DAGCombiner::extendLoadedValueToExtension(LoadSDNode *LD, SDValue &Val) { return false; } +StoreSDNode *DAGCombiner::getUniqueStoreFeeding(LoadSDNode *LD, + int64_t &Offset) { + SDValue Chain = LD->getOperand(0); + + // Look through CALLSEQ_START. + if (Chain.getOpcode() == ISD::CALLSEQ_START) + Chain = Chain->getOperand(0); + + StoreSDNode *ST = nullptr; + SmallVector<SDValue, 8> Aliases; + if (Chain.getOpcode() == ISD::TokenFactor) { + // Look for unique store within the TokenFactor. + for (SDValue Op : Chain->ops()) { + StoreSDNode *Store = dyn_cast<StoreSDNode>(Op.getNode()); + if (!Store) + continue; + BaseIndexOffset BasePtrLD = BaseIndexOffset::match(LD, DAG); + BaseIndexOffset BasePtrST = BaseIndexOffset::match(Store, DAG); + if (!BasePtrST.equalBaseIndex(BasePtrLD, DAG, Offset)) + continue; + // Make sure the store is not aliased with any nodes in TokenFactor. + GatherAllAliases(Store, Chain, Aliases); + if (Aliases.empty() || + (Aliases.size() == 1 && Aliases.front().getNode() == Store)) + ST = Store; + break; + } + } else { + StoreSDNode *Store = dyn_cast<StoreSDNode>(Chain.getNode()); + if (Store) { + BaseIndexOffset BasePtrLD = BaseIndexOffset::match(LD, DAG); + BaseIndexOffset BasePtrST = BaseIndexOffset::match(Store, DAG); + if (BasePtrST.equalBaseIndex(BasePtrLD, DAG, Offset)) + ST = Store; + } + } + + return ST; +} + SDValue DAGCombiner::ForwardStoreValueToDirectLoad(LoadSDNode *LD) { if (OptLevel == CodeGenOpt::None || !LD->isSimple()) return SDValue(); SDValue Chain = LD->getOperand(0); - StoreSDNode *ST = dyn_cast<StoreSDNode>(Chain.getNode()); + int64_t Offset; + + StoreSDNode *ST = getUniqueStoreFeeding(LD, Offset); // TODO: Relax this restriction for unordered atomics (see D66309) if (!ST || !ST->isSimple() || ST->getAddressSpace() != LD->getAddressSpace()) return SDValue(); @@ -17309,8 +18092,8 @@ SDValue DAGCombiner::ForwardStoreValueToDirectLoad(LoadSDNode *LD) { // 2. The store is scalable and the load is fixed width. We could // potentially support a limited number of cases here, but there has been // no cost-benefit analysis to prove it's worth it. - bool LdStScalable = LDMemType.isScalableVector(); - if (LdStScalable != STMemType.isScalableVector()) + bool LdStScalable = LDMemType.isScalableVT(); + if (LdStScalable != STMemType.isScalableVT()) return SDValue(); // If we are dealing with scalable vectors on a big endian platform the @@ -17320,12 +18103,6 @@ SDValue DAGCombiner::ForwardStoreValueToDirectLoad(LoadSDNode *LD) { if (LdStScalable && DAG.getDataLayout().isBigEndian()) return SDValue(); - BaseIndexOffset BasePtrLD = BaseIndexOffset::match(LD, DAG); - BaseIndexOffset BasePtrST = BaseIndexOffset::match(ST, DAG); - int64_t Offset; - if (!BasePtrST.equalBaseIndex(BasePtrLD, DAG, Offset)) - return SDValue(); - // Normalize for Endianness. After this Offset=0 will denote that the least // significant bit in the loaded value maps to the least significant bit in // the stored value). With Offset=n (for n > 0) the loaded value starts at the @@ -17682,7 +18459,7 @@ struct LoadedSlice { /// Get the size of the slice to be loaded in bytes. unsigned getLoadedSize() const { - unsigned SliceSize = getUsedBits().countPopulation(); + unsigned SliceSize = getUsedBits().popcount(); assert(!(SliceSize & 0x7) && "Size is not a multiple of a byte."); return SliceSize / 8; } @@ -17867,9 +18644,9 @@ static bool areUsedBitsDense(const APInt &UsedBits) { return true; // Get rid of the unused bits on the right. - APInt NarrowedUsedBits = UsedBits.lshr(UsedBits.countTrailingZeros()); + APInt NarrowedUsedBits = UsedBits.lshr(UsedBits.countr_zero()); // Get rid of the unused bits on the left. - if (NarrowedUsedBits.countLeadingZeros()) + if (NarrowedUsedBits.countl_zero()) NarrowedUsedBits = NarrowedUsedBits.trunc(NarrowedUsedBits.getActiveBits()); // Check that the chunk of bits is completely used. return NarrowedUsedBits.isAllOnes(); @@ -18125,14 +18902,14 @@ CheckForMaskedLoad(SDValue V, SDValue Ptr, SDValue Chain) { // 0 and the bits being kept are 1. Use getSExtValue so that leading bits // follow the sign bit for uniformity. uint64_t NotMask = ~cast<ConstantSDNode>(V->getOperand(1))->getSExtValue(); - unsigned NotMaskLZ = countLeadingZeros(NotMask); + unsigned NotMaskLZ = llvm::countl_zero(NotMask); if (NotMaskLZ & 7) return Result; // Must be multiple of a byte. - unsigned NotMaskTZ = countTrailingZeros(NotMask); + unsigned NotMaskTZ = llvm::countr_zero(NotMask); if (NotMaskTZ & 7) return Result; // Must be multiple of a byte. if (NotMaskLZ == 64) return Result; // All zero mask. // See if we have a continuous run of bits. If so, we have 0*1+0* - if (countTrailingOnes(NotMask >> NotMaskTZ) + NotMaskTZ + NotMaskLZ != 64) + if (llvm::countr_one(NotMask >> NotMaskTZ) + NotMaskTZ + NotMaskLZ != 64) return Result; // Adjust NotMaskLZ down to be from the actual size of the int instead of i64. @@ -18199,6 +18976,11 @@ ShrinkLoadReplaceStoreWithStore(const std::pair<unsigned, unsigned> &MaskInfo, UseTruncStore = true; else return SDValue(); + + // Can't do this for indexed stores. + if (St->isIndexed()) + return SDValue(); + // Check that the target doesn't think this is a bad idea. if (St->getMemOperand() && !TLI.allowsMemoryAccess(*DAG.getContext(), DAG.getDataLayout(), VT, @@ -18309,8 +19091,8 @@ SDValue DAGCombiner::ReduceLoadOpStoreWidth(SDNode *N) { Imm ^= APInt::getAllOnes(BitWidth); if (Imm == 0 || Imm.isAllOnes()) return SDValue(); - unsigned ShAmt = Imm.countTrailingZeros(); - unsigned MSB = BitWidth - Imm.countLeadingZeros() - 1; + unsigned ShAmt = Imm.countr_zero(); + unsigned MSB = BitWidth - Imm.countl_zero() - 1; unsigned NewBW = NextPowerOf2(MSB - ShAmt); EVT NewVT = EVT::getIntegerVT(*DAG.getContext(), NewBW); // The narrowing should be profitable, the load/store operation should be @@ -18527,6 +19309,30 @@ SDValue DAGCombiner::getMergeStoreChains(SmallVectorImpl<MemOpLink> &StoreNodes, return DAG.getTokenFactor(StoreDL, Chains); } +bool DAGCombiner::hasSameUnderlyingObj(ArrayRef<MemOpLink> StoreNodes) { + const Value *UnderlyingObj = nullptr; + for (const auto &MemOp : StoreNodes) { + const MachineMemOperand *MMO = MemOp.MemNode->getMemOperand(); + // Pseudo value like stack frame has its own frame index and size, should + // not use the first store's frame index for other frames. + if (MMO->getPseudoValue()) + return false; + + if (!MMO->getValue()) + return false; + + const Value *Obj = getUnderlyingObject(MMO->getValue()); + + if (UnderlyingObj && UnderlyingObj != Obj) + return false; + + if (!UnderlyingObj) + UnderlyingObj = Obj; + } + + return true; +} + bool DAGCombiner::mergeStoresOfConstantsOrVecElts( SmallVectorImpl<MemOpLink> &StoreNodes, EVT MemVT, unsigned NumStores, bool IsConstantSrc, bool UseVector, bool UseTrunc) { @@ -18678,13 +19484,21 @@ bool DAGCombiner::mergeStoresOfConstantsOrVecElts( LSBaseSDNode *FirstInChain = StoreNodes[0].MemNode; SDValue NewChain = getMergeStoreChains(StoreNodes, NumStores); + bool CanReusePtrInfo = hasSameUnderlyingObj(StoreNodes); // make sure we use trunc store if it's necessary to be legal. + // When generate the new widen store, if the first store's pointer info can + // not be reused, discard the pointer info except the address space because + // now the widen store can not be represented by the original pointer info + // which is for the narrow memory object. SDValue NewStore; if (!UseTrunc) { - NewStore = DAG.getStore(NewChain, DL, StoredVal, FirstInChain->getBasePtr(), - FirstInChain->getPointerInfo(), - FirstInChain->getAlign(), *Flags, AAInfo); + NewStore = DAG.getStore( + NewChain, DL, StoredVal, FirstInChain->getBasePtr(), + CanReusePtrInfo + ? FirstInChain->getPointerInfo() + : MachinePointerInfo(FirstInChain->getPointerInfo().getAddrSpace()), + FirstInChain->getAlign(), *Flags, AAInfo); } else { // Must be realized as a trunc store EVT LegalizedStoredValTy = TLI.getTypeToTransformTo(*DAG.getContext(), StoredVal.getValueType()); @@ -18695,8 +19509,11 @@ bool DAGCombiner::mergeStoresOfConstantsOrVecElts( LegalizedStoredValTy); NewStore = DAG.getTruncStore( NewChain, DL, ExtendedStoreVal, FirstInChain->getBasePtr(), - FirstInChain->getPointerInfo(), StoredVal.getValueType() /*TVT*/, - FirstInChain->getAlign(), *Flags, AAInfo); + CanReusePtrInfo + ? FirstInChain->getPointerInfo() + : MachinePointerInfo(FirstInChain->getPointerInfo().getAddrSpace()), + StoredVal.getValueType() /*TVT*/, FirstInChain->getAlign(), *Flags, + AAInfo); } // Replace all merged stores with the new store. @@ -18749,6 +19566,8 @@ void DAGCombiner::getStoreMergeCandidates( // Don't mix temporal stores with non-temporal stores. if (St->isNonTemporal() != Other->isNonTemporal()) return false; + if (!TLI.areTwoSDNodeTargetMMOFlagsMergeable(*St, *Other)) + return false; SDValue OtherBC = peekThroughBitcasts(Other->getValue()); // Allow merging constants of different types as integers. bool NoTypeMatch = (MemVT.isInteger()) ? !MemVT.bitsEq(Other->getMemoryVT()) @@ -18774,6 +19593,9 @@ void DAGCombiner::getStoreMergeCandidates( // Don't mix temporal loads with non-temporal loads. if (cast<LoadSDNode>(Val)->isNonTemporal() != OtherLd->isNonTemporal()) return false; + if (!TLI.areTwoSDNodeTargetMMOFlagsMergeable(*cast<LoadSDNode>(Val), + *OtherLd)) + return false; if (!(LBasePtr.equalBaseIndex(LPtr, DAG))) return false; break; @@ -19042,11 +19864,9 @@ bool DAGCombiner::tryStoreMergeOfConstants( } } - // We only use vectors if the constant is known to be zero or the - // target allows it and the function is not marked with the - // noimplicitfloat attribute. - if ((!NonZero || - TLI.storeOfVectorConstantIsCheap(MemVT, i + 1, FirstStoreAS)) && + // We only use vectors if the target allows it and the function is not + // marked with the noimplicitfloat attribute. + if (TLI.storeOfVectorConstantIsCheap(!NonZero, MemVT, i + 1, FirstStoreAS) && AllowVectors) { // Find a legal type for the vector store. unsigned Elts = (i + 1) * NumMemElts; @@ -19389,6 +20209,7 @@ bool DAGCombiner::tryStoreMergeOfLoads(SmallVectorImpl<MemOpLink> &StoreNodes, // using the first's chain is acceptable. SDValue NewStoreChain = getMergeStoreChains(StoreNodes, NumElem); + bool CanReusePtrInfo = hasSameUnderlyingObj(StoreNodes); AddToWorklist(NewStoreChain.getNode()); MachineMemOperand::Flags LdMMOFlags = @@ -19397,10 +20218,14 @@ bool DAGCombiner::tryStoreMergeOfLoads(SmallVectorImpl<MemOpLink> &StoreNodes, if (IsNonTemporalLoad) LdMMOFlags |= MachineMemOperand::MONonTemporal; + LdMMOFlags |= TLI.getTargetMMOFlags(*FirstLoad); + MachineMemOperand::Flags StMMOFlags = IsNonTemporalStore ? MachineMemOperand::MONonTemporal : MachineMemOperand::MONone; + StMMOFlags |= TLI.getTargetMMOFlags(*StoreNodes[0].MemNode); + SDValue NewLoad, NewStore; if (UseVectorTy || !DoIntegerTruncate) { NewLoad = DAG.getLoad( @@ -19418,7 +20243,9 @@ bool DAGCombiner::tryStoreMergeOfLoads(SmallVectorImpl<MemOpLink> &StoreNodes, } NewStore = DAG.getStore( NewStoreChain, StoreDL, StoreOp, FirstInChain->getBasePtr(), - FirstInChain->getPointerInfo(), FirstStoreAlign, StMMOFlags); + CanReusePtrInfo ? FirstInChain->getPointerInfo() + : MachinePointerInfo(FirstStoreAS), + FirstStoreAlign, StMMOFlags); } else { // This must be the truncstore/extload case EVT ExtendedTy = TLI.getTypeToTransformTo(*DAG.getContext(), JointMemOpVT); @@ -19428,8 +20255,10 @@ bool DAGCombiner::tryStoreMergeOfLoads(SmallVectorImpl<MemOpLink> &StoreNodes, FirstLoadAlign, LdMMOFlags); NewStore = DAG.getTruncStore( NewStoreChain, StoreDL, NewLoad, FirstInChain->getBasePtr(), - FirstInChain->getPointerInfo(), JointMemOpVT, - FirstInChain->getAlign(), FirstInChain->getMemOperand()->getFlags()); + CanReusePtrInfo ? FirstInChain->getPointerInfo() + : MachinePointerInfo(FirstStoreAS), + JointMemOpVT, FirstInChain->getAlign(), + FirstInChain->getMemOperand()->getFlags()); } // Transfer chain users from old loads to the new load. @@ -19465,7 +20294,7 @@ bool DAGCombiner::mergeConsecutiveStores(StoreSDNode *St) { // store since we know <vscale x 16 x i8> is exactly twice as large as // <vscale x 8 x i8>). Until then, bail out for scalable vectors. EVT MemVT = St->getMemoryVT(); - if (MemVT.isScalableVector()) + if (MemVT.isScalableVT()) return false; if (!MemVT.isSimple() || MemVT.getSizeInBits() * 2 > MaximumLegalStoreInBits) return false; @@ -19647,6 +20476,62 @@ SDValue DAGCombiner::replaceStoreOfFPConstant(StoreSDNode *ST) { } } +// (store (insert_vector_elt (load p), x, i), p) -> (store x, p+offset) +// +// If a store of a load with an element inserted into it has no other +// uses in between the chain, then we can consider the vector store +// dead and replace it with just the single scalar element store. +SDValue DAGCombiner::replaceStoreOfInsertLoad(StoreSDNode *ST) { + SDLoc DL(ST); + SDValue Value = ST->getValue(); + SDValue Ptr = ST->getBasePtr(); + SDValue Chain = ST->getChain(); + if (Value.getOpcode() != ISD::INSERT_VECTOR_ELT || !Value.hasOneUse()) + return SDValue(); + + SDValue Elt = Value.getOperand(1); + SDValue Idx = Value.getOperand(2); + + // If the element isn't byte sized then we can't compute an offset + EVT EltVT = Elt.getValueType(); + if (!EltVT.isByteSized()) + return SDValue(); + + auto *Ld = dyn_cast<LoadSDNode>(Value.getOperand(0)); + if (!Ld || Ld->getBasePtr() != Ptr || + ST->getMemoryVT() != Ld->getMemoryVT() || !ST->isSimple() || + !ISD::isNormalStore(ST) || + Ld->getAddressSpace() != ST->getAddressSpace() || + !Chain.reachesChainWithoutSideEffects(SDValue(Ld, 1))) + return SDValue(); + + unsigned IsFast; + if (!TLI.allowsMemoryAccess(*DAG.getContext(), DAG.getDataLayout(), + Elt.getValueType(), ST->getAddressSpace(), + ST->getAlign(), ST->getMemOperand()->getFlags(), + &IsFast) || + !IsFast) + return SDValue(); + EVT PtrVT = Ptr.getValueType(); + + SDValue Offset = + DAG.getNode(ISD::MUL, DL, PtrVT, Idx, + DAG.getConstant(EltVT.getSizeInBits() / 8, DL, PtrVT)); + SDValue NewPtr = DAG.getNode(ISD::ADD, DL, PtrVT, Ptr, Offset); + MachinePointerInfo PointerInfo(ST->getAddressSpace()); + + // If the offset is a known constant then try to recover the pointer + // info + if (auto *CIdx = dyn_cast<ConstantSDNode>(Idx)) { + unsigned COffset = CIdx->getSExtValue() * EltVT.getSizeInBits() / 8; + NewPtr = DAG.getMemBasePlusOffset(Ptr, TypeSize::Fixed(COffset), DL); + PointerInfo = ST->getPointerInfo().getWithOffset(COffset); + } + + return DAG.getStore(Chain, DL, Elt, NewPtr, PointerInfo, ST->getAlign(), + ST->getMemOperand()->getFlags()); +} + SDValue DAGCombiner::visitSTORE(SDNode *N) { StoreSDNode *ST = cast<StoreSDNode>(N); SDValue Chain = ST->getChain(); @@ -19768,9 +20653,13 @@ SDValue DAGCombiner::visitSTORE(SDNode *N) { } // If this is a load followed by a store to the same location, then the store - // is dead/noop. + // is dead/noop. Peek through any truncates if canCombineTruncStore failed. + // TODO: Add big-endian truncate support with test coverage. // TODO: Can relax for unordered atomics (see D66309) - if (LoadSDNode *Ld = dyn_cast<LoadSDNode>(Value)) { + SDValue TruncVal = DAG.getDataLayout().isLittleEndian() + ? peekThroughTruncates(Value) + : Value; + if (auto *Ld = dyn_cast<LoadSDNode>(TruncVal)) { if (Ld->getBasePtr() == Ptr && ST->getMemoryVT() == Ld->getMemoryVT() && ST->isUnindexed() && ST->isSimple() && Ld->getAddressSpace() == ST->getAddressSpace() && @@ -19782,6 +20671,10 @@ SDValue DAGCombiner::visitSTORE(SDNode *N) { } } + // Try scalarizing vector stores of loads where we only change one element + if (SDValue NewST = replaceStoreOfInsertLoad(ST)) + return NewST; + // TODO: Can relax for unordered atomics (see D66309) if (StoreSDNode *ST1 = dyn_cast<StoreSDNode>(Chain)) { if (ST->isUnindexed() && ST->isSimple() && @@ -19796,22 +20689,32 @@ SDValue DAGCombiner::visitSTORE(SDNode *N) { if (OptLevel != CodeGenOpt::None && ST1->hasOneUse() && !ST1->getBasePtr().isUndef() && - // BaseIndexOffset and the code below requires knowing the size - // of a vector, so bail out if MemoryVT is scalable. - !ST->getMemoryVT().isScalableVector() && - !ST1->getMemoryVT().isScalableVector() && ST->getAddressSpace() == ST1->getAddressSpace()) { - const BaseIndexOffset STBase = BaseIndexOffset::match(ST, DAG); - const BaseIndexOffset ChainBase = BaseIndexOffset::match(ST1, DAG); - unsigned STBitSize = ST->getMemoryVT().getFixedSizeInBits(); - unsigned ChainBitSize = ST1->getMemoryVT().getFixedSizeInBits(); - // If this is a store who's preceding store to a subset of the current - // location and no one other node is chained to that store we can - // effectively drop the store. Do not remove stores to undef as they may - // be used as data sinks. - if (STBase.contains(DAG, STBitSize, ChainBase, ChainBitSize)) { - CombineTo(ST1, ST1->getChain()); - return SDValue(); + // If we consider two stores and one smaller in size is a scalable + // vector type and another one a bigger size store with a fixed type, + // then we could not allow the scalable store removal because we don't + // know its final size in the end. + if (ST->getMemoryVT().isScalableVector() || + ST1->getMemoryVT().isScalableVector()) { + if (ST1->getBasePtr() == Ptr && + TypeSize::isKnownLE(ST1->getMemoryVT().getStoreSize(), + ST->getMemoryVT().getStoreSize())) { + CombineTo(ST1, ST1->getChain()); + return SDValue(); + } + } else { + const BaseIndexOffset STBase = BaseIndexOffset::match(ST, DAG); + const BaseIndexOffset ChainBase = BaseIndexOffset::match(ST1, DAG); + // If this is a store who's preceding store to a subset of the current + // location and no one other node is chained to that store we can + // effectively drop the store. Do not remove stores to undef as they + // may be used as data sinks. + if (STBase.contains(DAG, ST->getMemoryVT().getFixedSizeInBits(), + ChainBase, + ST1->getMemoryVT().getFixedSizeInBits())) { + CombineTo(ST1, ST1->getChain()); + return SDValue(); + } } } } @@ -20183,6 +21086,99 @@ SDValue DAGCombiner::combineInsertEltToShuffle(SDNode *N, unsigned InsIndex) { return DAG.getBitcast(VT, Shuf); } +// Combine insert(shuffle(load, <u,0,1,2>), load, 0) into a single load if +// possible and the new load will be quick. We use more loads but less shuffles +// and inserts. +SDValue DAGCombiner::combineInsertEltToLoad(SDNode *N, unsigned InsIndex) { + EVT VT = N->getValueType(0); + + // InsIndex is expected to be the first of last lane. + if (!VT.isFixedLengthVector() || + (InsIndex != 0 && InsIndex != VT.getVectorNumElements() - 1)) + return SDValue(); + + // Look for a shuffle with the mask u,0,1,2,3,4,5,6 or 1,2,3,4,5,6,7,u + // depending on the InsIndex. + auto *Shuffle = dyn_cast<ShuffleVectorSDNode>(N->getOperand(0)); + SDValue Scalar = N->getOperand(1); + if (!Shuffle || !all_of(enumerate(Shuffle->getMask()), [&](auto P) { + return InsIndex == P.index() || P.value() < 0 || + (InsIndex == 0 && P.value() == (int)P.index() - 1) || + (InsIndex == VT.getVectorNumElements() - 1 && + P.value() == (int)P.index() + 1); + })) + return SDValue(); + + // We optionally skip over an extend so long as both loads are extended in the + // same way from the same type. + unsigned Extend = 0; + if (Scalar.getOpcode() == ISD::ZERO_EXTEND || + Scalar.getOpcode() == ISD::SIGN_EXTEND || + Scalar.getOpcode() == ISD::ANY_EXTEND) { + Extend = Scalar.getOpcode(); + Scalar = Scalar.getOperand(0); + } + + auto *ScalarLoad = dyn_cast<LoadSDNode>(Scalar); + if (!ScalarLoad) + return SDValue(); + + SDValue Vec = Shuffle->getOperand(0); + if (Extend) { + if (Vec.getOpcode() != Extend) + return SDValue(); + Vec = Vec.getOperand(0); + } + auto *VecLoad = dyn_cast<LoadSDNode>(Vec); + if (!VecLoad || Vec.getValueType().getScalarType() != Scalar.getValueType()) + return SDValue(); + + int EltSize = ScalarLoad->getValueType(0).getScalarSizeInBits(); + if (EltSize == 0 || EltSize % 8 != 0 || !ScalarLoad->isSimple() || + !VecLoad->isSimple() || VecLoad->getExtensionType() != ISD::NON_EXTLOAD || + ScalarLoad->getExtensionType() != ISD::NON_EXTLOAD || + ScalarLoad->getAddressSpace() != VecLoad->getAddressSpace()) + return SDValue(); + + // Check that the offset between the pointers to produce a single continuous + // load. + if (InsIndex == 0) { + if (!DAG.areNonVolatileConsecutiveLoads(ScalarLoad, VecLoad, EltSize / 8, + -1)) + return SDValue(); + } else { + if (!DAG.areNonVolatileConsecutiveLoads( + VecLoad, ScalarLoad, VT.getVectorNumElements() * EltSize / 8, -1)) + return SDValue(); + } + + // And that the new unaligned load will be fast. + unsigned IsFast = 0; + Align NewAlign = commonAlignment(VecLoad->getAlign(), EltSize / 8); + if (!TLI.allowsMemoryAccess(*DAG.getContext(), DAG.getDataLayout(), + Vec.getValueType(), VecLoad->getAddressSpace(), + NewAlign, VecLoad->getMemOperand()->getFlags(), + &IsFast) || + !IsFast) + return SDValue(); + + // Calculate the new Ptr and create the new load. + SDLoc DL(N); + SDValue Ptr = ScalarLoad->getBasePtr(); + if (InsIndex != 0) + Ptr = DAG.getNode(ISD::ADD, DL, Ptr.getValueType(), VecLoad->getBasePtr(), + DAG.getConstant(EltSize / 8, DL, Ptr.getValueType())); + MachinePointerInfo PtrInfo = + InsIndex == 0 ? ScalarLoad->getPointerInfo() + : VecLoad->getPointerInfo().getWithOffset(EltSize / 8); + + SDValue Load = DAG.getLoad(VecLoad->getValueType(0), DL, + ScalarLoad->getChain(), Ptr, PtrInfo, NewAlign); + DAG.makeEquivalentMemoryOrdering(ScalarLoad, Load.getValue(1)); + DAG.makeEquivalentMemoryOrdering(VecLoad, Load.getValue(1)); + return Extend ? DAG.getNode(Extend, DL, VT, Load) : Load; +} + SDValue DAGCombiner::visitINSERT_VECTOR_ELT(SDNode *N) { SDValue InVec = N->getOperand(0); SDValue InVal = N->getOperand(1); @@ -20254,6 +21250,9 @@ SDValue DAGCombiner::visitINSERT_VECTOR_ELT(SDNode *N) { if (SDValue Shuf = combineInsertEltToShuffle(N, Elt)) return Shuf; + if (SDValue Shuf = combineInsertEltToLoad(N, Elt)) + return Shuf; + // Attempt to convert an insert_vector_elt chain into a legal build_vector. if (!LegalOperations || TLI.isOperationLegal(ISD::BUILD_VECTOR, VT)) { // vXi1 vector - we don't need to recurse. @@ -20349,6 +21348,20 @@ SDValue DAGCombiner::visitINSERT_VECTOR_ELT(SDNode *N) { return NewShuffle; } + // If all insertions are zero value, try to convert to AND mask. + // TODO: Do this for -1 with OR mask? + if (!LegalOperations && llvm::isNullConstant(InVal) && + all_of(Ops, [InVal](SDValue Op) { return !Op || Op == InVal; }) && + count_if(Ops, [InVal](SDValue Op) { return Op == InVal; }) >= 2) { + SDValue Zero = DAG.getConstant(0, DL, MaxEltVT); + SDValue AllOnes = DAG.getAllOnesConstant(DL, MaxEltVT); + SmallVector<SDValue, 8> Mask(NumElts); + for (unsigned I = 0; I != NumElts; ++I) + Mask[I] = Ops[I] ? Zero : AllOnes; + return DAG.getNode(ISD::AND, DL, VT, CurVec, + DAG.getBuildVector(VT, DL, Mask)); + } + // Failed to find a match in the chain - bail. break; } @@ -20701,8 +21714,7 @@ SDValue DAGCombiner::visitEXTRACT_VECTOR_ELT(SDNode *N) { // extract_vector_elt (build_vector x, y), 1 -> y if (((IndexC && VecOp.getOpcode() == ISD::BUILD_VECTOR) || VecOp.getOpcode() == ISD::SPLAT_VECTOR) && - TLI.isTypeLegal(VecVT) && - (VecOp.hasOneUse() || TLI.aggressivelyPreferBuildVectorSources(VecVT))) { + TLI.isTypeLegal(VecVT)) { assert((VecOp.getOpcode() != ISD::BUILD_VECTOR || VecVT.isFixedLengthVector()) && "BUILD_VECTOR used for scalable vectors"); @@ -20711,12 +21723,15 @@ SDValue DAGCombiner::visitEXTRACT_VECTOR_ELT(SDNode *N) { SDValue Elt = VecOp.getOperand(IndexVal); EVT InEltVT = Elt.getValueType(); - // Sometimes build_vector's scalar input types do not match result type. - if (ScalarVT == InEltVT) - return Elt; + if (VecOp.hasOneUse() || TLI.aggressivelyPreferBuildVectorSources(VecVT) || + isNullConstant(Elt)) { + // Sometimes build_vector's scalar input types do not match result type. + if (ScalarVT == InEltVT) + return Elt; - // TODO: It may be useful to truncate if free if the build_vector implicitly - // converts. + // TODO: It may be useful to truncate if free if the build_vector + // implicitly converts. + } } if (SDValue BO = scalarizeExtractedBinop(N, DAG, LegalOperations)) @@ -21025,9 +22040,10 @@ SDValue DAGCombiner::reduceBuildVecExtToExtBuildVec(SDNode *N) { // same source type and all of the inputs must be any or zero extend. // Scalar sizes must be a power of two. EVT OutScalarTy = VT.getScalarType(); - bool ValidTypes = SourceType != MVT::Other && - isPowerOf2_32(OutScalarTy.getSizeInBits()) && - isPowerOf2_32(SourceType.getSizeInBits()); + bool ValidTypes = + SourceType != MVT::Other && + llvm::has_single_bit<uint32_t>(OutScalarTy.getSizeInBits()) && + llvm::has_single_bit<uint32_t>(SourceType.getSizeInBits()); // Create a new simpler BUILD_VECTOR sequence which other optimizations can // turn into a single shuffle instruction. @@ -21157,7 +22173,7 @@ SDValue DAGCombiner::reduceBuildVecTruncToBitCast(SDNode *N) { } // Only cast if the size is the same - if (Src.getValueType().getSizeInBits() != VT.getSizeInBits()) + if (!Src || Src.getValueType().getSizeInBits() != VT.getSizeInBits()) return SDValue(); return DAG.getBitcast(VT, Src); @@ -21359,10 +22375,9 @@ static SDValue reduceBuildVecToShuffleWithZero(SDNode *BV, SelectionDAG &DAG) { // the source vector. The high bits map to zero. We will use a zero vector // as the 2nd source operand of the shuffle, so use the 1st element of // that vector (mask value is number-of-elements) for the high bits. - if (i % ZextRatio == 0) - ShufMask[i] = Extract.getConstantOperandVal(1); - else - ShufMask[i] = NumMaskElts; + int Low = DAG.getDataLayout().isBigEndian() ? (ZextRatio - 1) : 0; + ShufMask[i] = (i % ZextRatio == Low) ? Extract.getConstantOperandVal(1) + : NumMaskElts; } // Undef elements of the build vector remain undef because we initialize @@ -21917,7 +22932,7 @@ static SDValue combineConcatVectorOfScalars(SDNode *N, SelectionDAG &DAG) { EVT OpVT = N->getOperand(0).getValueType(); // If the operands are legal vectors, leave them alone. - if (TLI.isTypeLegal(OpVT)) + if (TLI.isTypeLegal(OpVT) || OpVT.isScalableVector()) return SDValue(); SDLoc DL(N); @@ -22273,7 +23288,13 @@ SDValue DAGCombiner::visitCONCAT_VECTORS(SDNode *N) { // If the input is a concat_vectors, just make a larger concat by padding // with smaller undefs. - if (In.getOpcode() == ISD::CONCAT_VECTORS && In.hasOneUse()) { + // + // Legalizing in AArch64TargetLowering::LowerCONCAT_VECTORS() and combining + // here could cause an infinite loop. That legalizing happens when LegalDAG + // is true and input of AArch64TargetLowering::LowerCONCAT_VECTORS() is + // scalable. + if (In.getOpcode() == ISD::CONCAT_VECTORS && In.hasOneUse() && + !(LegalDAG && In.getValueType().isScalableVector())) { unsigned NumOps = N->getNumOperands() * In.getNumOperands(); SmallVector<SDValue, 4> Ops(In->op_begin(), In->op_end()); Ops.resize(NumOps, DAG.getUNDEF(Ops[0].getValueType())); @@ -22767,10 +23788,6 @@ static SDValue foldExtractSubvectorFromShuffleVector(SDNode *N, continue; } - // Profitability check: only deal with extractions from the first subvector. - if (OpSubvecIdx != 0) - return SDValue(); - const std::pair<SDValue, int> DemandedSubvector = std::make_pair(Op, OpSubvecIdx); @@ -22800,6 +23817,14 @@ static SDValue foldExtractSubvectorFromShuffleVector(SDNode *N, if (DemandedSubvectors.empty()) return DAG.getUNDEF(NarrowVT); + // Profitability check: only deal with extractions from the first subvector + // unless the mask becomes an identity mask. + if (!ShuffleVectorInst::isIdentityMask(NewMask) || + any_of(NewMask, [](int M) { return M < 0; })) + for (auto &DemandedSubvector : DemandedSubvectors) + if (DemandedSubvector.second != 0) + return SDValue(); + // We still perform the exact same EXTRACT_SUBVECTOR, just on different // operand[s]/index[es], so there is no point in checking for it's legality. @@ -22975,7 +24000,7 @@ SDValue DAGCombiner::visitEXTRACT_SUBVECTOR(SDNode *N) { if (NumElems == 1) { SDValue Src = V->getOperand(IdxVal); if (EltVT != Src.getValueType()) - Src = DAG.getNode(ISD::TRUNCATE, SDLoc(N), InVT, Src); + Src = DAG.getNode(ISD::TRUNCATE, SDLoc(N), EltVT, Src); return DAG.getBitcast(NVT, Src); } @@ -23450,9 +24475,7 @@ static SDValue combineTruncationShuffle(ShuffleVectorSDNode *SVN, SDValue N0 = peekThroughBitcasts(SVN->getOperand(0)); unsigned Opcode = N0.getOpcode(); - if (Opcode != ISD::ANY_EXTEND_VECTOR_INREG && - Opcode != ISD::SIGN_EXTEND_VECTOR_INREG && - Opcode != ISD::ZERO_EXTEND_VECTOR_INREG) + if (!ISD::isExtVecInRegOpcode(Opcode)) return SDValue(); SDValue N00 = N0.getOperand(0); @@ -23518,7 +24541,7 @@ static SDValue combineShuffleOfSplatVal(ShuffleVectorSDNode *Shuf, assert((unsigned)Idx < NumElts && "Out-of-bounds shuffle indice?"); DemandedElts.setBit(Idx); } - assert(DemandedElts.countPopulation() > 1 && "Is a splat shuffle already?"); + assert(DemandedElts.popcount() > 1 && "Is a splat shuffle already?"); APInt UndefElts; if (DAG.isSplatValue(Shuf->getOperand(0), DemandedElts, UndefElts)) { // Even if all demanded elements are splat, some of them could be undef. @@ -24072,8 +25095,8 @@ SDValue DAGCombiner::visitVECTOR_SHUFFLE(SDNode *N) { bool IsInLaneMask = true; ArrayRef<int> Mask = SVN->getMask(); SmallVector<int, 16> ClearMask(NumElts, -1); - APInt DemandedLHS = APInt::getNullValue(NumElts); - APInt DemandedRHS = APInt::getNullValue(NumElts); + APInt DemandedLHS = APInt::getZero(NumElts); + APInt DemandedRHS = APInt::getZero(NumElts); for (int I = 0; I != (int)NumElts; ++I) { int M = Mask[I]; if (M < 0) @@ -24086,12 +25109,9 @@ SDValue DAGCombiner::visitVECTOR_SHUFFLE(SDNode *N) { } } // TODO: Should we try to mask with N1 as well? - if (!IsInLaneMask && - (!DemandedLHS.isNullValue() || !DemandedRHS.isNullValue()) && - (DemandedLHS.isNullValue() || - DAG.MaskedVectorIsZero(N0, DemandedLHS)) && - (DemandedRHS.isNullValue() || - DAG.MaskedVectorIsZero(N1, DemandedRHS))) { + if (!IsInLaneMask && (!DemandedLHS.isZero() || !DemandedRHS.isZero()) && + (DemandedLHS.isZero() || DAG.MaskedVectorIsZero(N0, DemandedLHS)) && + (DemandedRHS.isZero() || DAG.MaskedVectorIsZero(N1, DemandedRHS))) { SDLoc DL(N); EVT IntVT = VT.changeVectorElementTypeToInteger(); EVT IntSVT = VT.getVectorElementType().changeTypeToInteger(); @@ -24771,6 +25791,17 @@ SDValue DAGCombiner::visitVECREDUCE(SDNode *N) { return SDValue(); } +SDValue DAGCombiner::visitVP_FSUB(SDNode *N) { + SelectionDAG::FlagInserter FlagsInserter(DAG, N); + + // FSUB -> FMA combines: + if (SDValue Fused = visitFSUBForFMACombine<VPMatchContext>(N)) { + AddToWorklist(Fused.getNode()); + return Fused; + } + return SDValue(); +} + SDValue DAGCombiner::visitVPOp(SDNode *N) { if (N->getOpcode() == ISD::VP_GATHER) @@ -24792,8 +25823,17 @@ SDValue DAGCombiner::visitVPOp(SDNode *N) { ISD::isConstantSplatVectorAllZeros(N->getOperand(*MaskIdx).getNode()); // This is the only generic VP combine we support for now. - if (!AreAllEltsDisabled) + if (!AreAllEltsDisabled) { + switch (N->getOpcode()) { + case ISD::VP_FADD: + return visitVP_FADD(N); + case ISD::VP_FSUB: + return visitVP_FSUB(N); + case ISD::VP_FMA: + return visitFMA<VPMatchContext>(N); + } return SDValue(); + } // Binary operations can be replaced by UNDEF. if (ISD::isVPBinaryOp(N->getOpcode())) @@ -24814,6 +25854,97 @@ SDValue DAGCombiner::visitVPOp(SDNode *N) { return SDValue(); } +SDValue DAGCombiner::visitGET_FPENV_MEM(SDNode *N) { + SDValue Chain = N->getOperand(0); + SDValue Ptr = N->getOperand(1); + EVT MemVT = cast<FPStateAccessSDNode>(N)->getMemoryVT(); + + // Check if the memory, where FP state is written to, is used only in a single + // load operation. + LoadSDNode *LdNode = nullptr; + for (auto *U : Ptr->uses()) { + if (U == N) + continue; + if (auto *Ld = dyn_cast<LoadSDNode>(U)) { + if (LdNode && LdNode != Ld) + return SDValue(); + LdNode = Ld; + continue; + } + return SDValue(); + } + if (!LdNode || !LdNode->isSimple() || LdNode->isIndexed() || + !LdNode->getOffset().isUndef() || LdNode->getMemoryVT() != MemVT || + !LdNode->getChain().reachesChainWithoutSideEffects(SDValue(N, 0))) + return SDValue(); + + // Check if the loaded value is used only in a store operation. + StoreSDNode *StNode = nullptr; + for (auto I = LdNode->use_begin(), E = LdNode->use_end(); I != E; ++I) { + SDUse &U = I.getUse(); + if (U.getResNo() == 0) { + if (auto *St = dyn_cast<StoreSDNode>(U.getUser())) { + if (StNode) + return SDValue(); + StNode = St; + } else { + return SDValue(); + } + } + } + if (!StNode || !StNode->isSimple() || StNode->isIndexed() || + !StNode->getOffset().isUndef() || StNode->getMemoryVT() != MemVT || + !StNode->getChain().reachesChainWithoutSideEffects(SDValue(LdNode, 1))) + return SDValue(); + + // Create new node GET_FPENV_MEM, which uses the store address to write FP + // environment. + SDValue Res = DAG.getGetFPEnv(Chain, SDLoc(N), StNode->getBasePtr(), MemVT, + StNode->getMemOperand()); + CombineTo(StNode, Res, false); + return Res; +} + +SDValue DAGCombiner::visitSET_FPENV_MEM(SDNode *N) { + SDValue Chain = N->getOperand(0); + SDValue Ptr = N->getOperand(1); + EVT MemVT = cast<FPStateAccessSDNode>(N)->getMemoryVT(); + + // Check if the address of FP state is used also in a store operation only. + StoreSDNode *StNode = nullptr; + for (auto *U : Ptr->uses()) { + if (U == N) + continue; + if (auto *St = dyn_cast<StoreSDNode>(U)) { + if (StNode && StNode != St) + return SDValue(); + StNode = St; + continue; + } + return SDValue(); + } + if (!StNode || !StNode->isSimple() || StNode->isIndexed() || + !StNode->getOffset().isUndef() || StNode->getMemoryVT() != MemVT || + !Chain.reachesChainWithoutSideEffects(SDValue(StNode, 0))) + return SDValue(); + + // Check if the stored value is loaded from some location and the loaded + // value is used only in the store operation. + SDValue StValue = StNode->getValue(); + auto *LdNode = dyn_cast<LoadSDNode>(StValue); + if (!LdNode || !LdNode->isSimple() || LdNode->isIndexed() || + !LdNode->getOffset().isUndef() || LdNode->getMemoryVT() != MemVT || + !StNode->getChain().reachesChainWithoutSideEffects(SDValue(LdNode, 1))) + return SDValue(); + + // Create new node SET_FPENV_MEM, which uses the load address to read FP + // environment. + SDValue Res = + DAG.getSetFPEnv(LdNode->getChain(), SDLoc(N), LdNode->getBasePtr(), MemVT, + LdNode->getMemOperand()); + return Res; +} + /// Returns a vector_shuffle if it able to transform an AND to a vector_shuffle /// with the destination vector and a zero vector. /// e.g. AND V, <0xffffffff, 0, 0xffffffff, 0>. ==> @@ -24960,8 +26091,6 @@ SDValue DAGCombiner::SimplifyVCastOp(SDNode *N, const SDLoc &DL) { unsigned Opcode = N->getOpcode(); SDValue N0 = N->getOperand(0); - EVT SrcVT = N0->getValueType(0); - EVT SrcEltVT = SrcVT.getVectorElementType(); const TargetLowering &TLI = DAG.getTargetLoweringInfo(); // TODO: promote operation might be also good here? @@ -24971,7 +26100,9 @@ SDValue DAGCombiner::SimplifyVCastOp(SDNode *N, const SDLoc &DL) { (N0.getOpcode() == ISD::SPLAT_VECTOR || TLI.isExtractVecEltCheap(VT, Index0)) && TLI.isOperationLegalOrCustom(Opcode, EltVT) && - TLI.preferScalarizeSplat(Opcode)) { + TLI.preferScalarizeSplat(N)) { + EVT SrcVT = N0.getValueType(); + EVT SrcEltVT = SrcVT.getVectorElementType(); SDValue IndexC = DAG.getVectorIdxConstant(Index0, DL); SDValue Elt = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, SrcEltVT, Src0, IndexC); @@ -25588,14 +26719,14 @@ SDValue DAGCombiner::SimplifySelectCC(const SDLoc &DL, SDValue N0, SDValue N1, N0->getValueType(0) == VT && isNullConstant(N1) && isNullConstant(N2)) { SDValue AndLHS = N0->getOperand(0); auto *ConstAndRHS = dyn_cast<ConstantSDNode>(N0->getOperand(1)); - if (ConstAndRHS && ConstAndRHS->getAPIntValue().countPopulation() == 1) { + if (ConstAndRHS && ConstAndRHS->getAPIntValue().popcount() == 1) { // Shift the tested bit over the sign bit. const APInt &AndMask = ConstAndRHS->getAPIntValue(); unsigned ShCt = AndMask.getBitWidth() - 1; if (!TLI.shouldAvoidTransformToShift(VT, ShCt)) { SDValue ShlAmt = - DAG.getConstant(AndMask.countLeadingZeros(), SDLoc(AndLHS), - getShiftAmountTy(AndLHS.getValueType())); + DAG.getConstant(AndMask.countl_zero(), SDLoc(AndLHS), + getShiftAmountTy(AndLHS.getValueType())); SDValue Shl = DAG.getNode(ISD::SHL, SDLoc(N0), VT, AndLHS, ShlAmt); // Now arithmetic right shift it all the way over, so the result is @@ -25991,7 +27122,7 @@ SDValue DAGCombiner::buildSqrtEstimateImpl(SDValue Op, SDNodeFlags Flags, Reciprocal)) { AddToWorklist(Est.getNode()); - if (Iterations) + if (Iterations > 0) Est = UseOneConstNR ? buildSqrtNROneConst(Op, Est, Iterations, Flags, Reciprocal) : buildSqrtNRTwoConst(Op, Est, Iterations, Flags, Reciprocal); @@ -26334,7 +27465,7 @@ bool DAGCombiner::parallelizeChainedStores(StoreSDNode *St) { // BaseIndexOffset assumes that offsets are fixed-size, which // is not valid for scalable vectors where the offsets are // scaled by `vscale`, so bail out early. - if (St->getMemoryVT().isScalableVector()) + if (St->getMemoryVT().isScalableVT()) return false; // Add ST's interval. |
