diff options
Diffstat (limited to 'contrib/llvm-project/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp')
-rw-r--r-- | contrib/llvm-project/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp | 1896 |
1 files changed, 1272 insertions, 624 deletions
diff --git a/contrib/llvm-project/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/contrib/llvm-project/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp index ec297579090e..aa688d9dda3c 100644 --- a/contrib/llvm-project/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp +++ b/contrib/llvm-project/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp @@ -35,7 +35,6 @@ #include "llvm/Analysis/VectorUtils.h" #include "llvm/CodeGen/DAGCombine.h" #include "llvm/CodeGen/ISDOpcodes.h" -#include "llvm/CodeGen/MachineFrameInfo.h" #include "llvm/CodeGen/MachineFunction.h" #include "llvm/CodeGen/MachineMemOperand.h" #include "llvm/CodeGen/RuntimeLibcalls.h" @@ -52,7 +51,6 @@ #include "llvm/IR/DataLayout.h" #include "llvm/IR/DerivedTypes.h" #include "llvm/IR/Function.h" -#include "llvm/IR/LLVMContext.h" #include "llvm/IR/Metadata.h" #include "llvm/Support/Casting.h" #include "llvm/Support/CodeGen.h" @@ -426,6 +424,7 @@ namespace { SDValue visitREM(SDNode *N); SDValue visitMULHU(SDNode *N); SDValue visitMULHS(SDNode *N); + SDValue visitAVG(SDNode *N); SDValue visitSMUL_LOHI(SDNode *N); SDValue visitUMUL_LOHI(SDNode *N); SDValue visitMULO(SDNode *N); @@ -511,6 +510,7 @@ namespace { SDValue visitMSCATTER(SDNode *N); SDValue visitFP_TO_FP16(SDNode *N); SDValue visitFP16_TO_FP(SDNode *N); + SDValue visitFP_TO_BF16(SDNode *N); SDValue visitVECREDUCE(SDNode *N); SDValue visitVPOp(SDNode *N); @@ -520,7 +520,9 @@ namespace { SDValue XformToShuffleWithZero(SDNode *N); bool reassociationCanBreakAddressingModePattern(unsigned Opc, - const SDLoc &DL, SDValue N0, + const SDLoc &DL, + SDNode *N, + SDValue N0, SDValue N1); SDValue reassociateOpsCommutative(unsigned Opc, const SDLoc &DL, SDValue N0, SDValue N1); @@ -570,6 +572,8 @@ namespace { SDValue BuildSDIV(SDNode *N); SDValue BuildSDIVPow2(SDNode *N); SDValue BuildUDIV(SDNode *N); + SDValue BuildSREMPow2(SDNode *N); + SDValue buildOptimizedSREM(SDValue N0, SDValue N1, SDNode *N); SDValue BuildLogBase2(SDValue V, const SDLoc &DL); SDValue BuildDivEstimate(SDValue N, SDValue Op, SDNodeFlags Flags); SDValue buildRsqrtEstimate(SDValue Op, SDNodeFlags Flags); @@ -583,11 +587,11 @@ namespace { bool DemandHighBits = true); SDValue MatchBSwapHWord(SDNode *N, SDValue N0, SDValue N1); SDValue MatchRotatePosNeg(SDValue Shifted, SDValue Pos, SDValue Neg, - SDValue InnerPos, SDValue InnerNeg, + SDValue InnerPos, SDValue InnerNeg, bool HasPos, unsigned PosOpcode, unsigned NegOpcode, const SDLoc &DL); SDValue MatchFunnelPosNeg(SDValue N0, SDValue N1, SDValue Pos, SDValue Neg, - SDValue InnerPos, SDValue InnerNeg, + SDValue InnerPos, SDValue InnerNeg, bool HasPos, unsigned PosOpcode, unsigned NegOpcode, const SDLoc &DL); SDValue MatchRotate(SDValue LHS, SDValue RHS, const SDLoc &DL); @@ -665,9 +669,8 @@ namespace { /// of folding (mul (add x, c1), c2) -> (add (mul x, c2), c1*c2). /// MulNode is the original multiply, AddNode is (add x, c1), /// and ConstNode is c2. - bool isMulAddWithConstProfitable(SDNode *MulNode, - SDValue &AddNode, - SDValue &ConstNode); + bool isMulAddWithConstProfitable(SDNode *MulNode, SDValue AddNode, + SDValue ConstNode); /// This is a helper function for visitAND and visitZERO_EXTEND. Returns /// true if the (and (load x) c) pattern matches an extload. ExtVT returns @@ -880,8 +883,8 @@ void DAGCombiner::deleteAndRecombine(SDNode *N) { // We provide an Offset so that we can create bitwidths that won't overflow. static void zeroExtendToMatch(APInt &LHS, APInt &RHS, unsigned Offset = 0) { unsigned Bits = Offset + std::max(LHS.getBitWidth(), RHS.getBitWidth()); - LHS = LHS.zextOrSelf(Bits); - RHS = RHS.zextOrSelf(Bits); + LHS = LHS.zext(Bits); + RHS = RHS.zext(Bits); } // Return true if this node is a setcc, or is a select_cc @@ -926,7 +929,7 @@ bool DAGCombiner::isSetCCEquivalent(SDValue N, SDValue &LHS, SDValue &RHS, /// it is profitable to do so. bool DAGCombiner::isOneUseSetCC(SDValue N) const { SDValue N0, N1, N2; - if (isSetCCEquivalent(N, N0, N1, N2) && N.getNode()->hasOneUse()) + if (isSetCCEquivalent(N, N0, N1, N2) && N->hasOneUse()) return true; return false; } @@ -996,6 +999,7 @@ static bool canSplitIdx(LoadSDNode *LD) { bool DAGCombiner::reassociationCanBreakAddressingModePattern(unsigned Opc, const SDLoc &DL, + SDNode *N, SDValue N0, SDValue N1) { // Currently this only tries to ensure we don't undo the GEP splits done by @@ -1004,33 +1008,62 @@ bool DAGCombiner::reassociationCanBreakAddressingModePattern(unsigned Opc, // (load/store (add, (add, x, offset1), offset2)) -> // (load/store (add, x, offset1+offset2)). - if (Opc != ISD::ADD || N0.getOpcode() != ISD::ADD) - return false; + // (load/store (add, (add, x, y), offset2)) -> + // (load/store (add, (add, x, offset2), y)). - if (N0.hasOneUse()) + if (Opc != ISD::ADD || N0.getOpcode() != ISD::ADD) return false; - auto *C1 = dyn_cast<ConstantSDNode>(N0.getOperand(1)); auto *C2 = dyn_cast<ConstantSDNode>(N1); - if (!C1 || !C2) + if (!C2) return false; - const APInt &C1APIntVal = C1->getAPIntValue(); const APInt &C2APIntVal = C2->getAPIntValue(); - if (C1APIntVal.getBitWidth() > 64 || C2APIntVal.getBitWidth() > 64) + if (C2APIntVal.getSignificantBits() > 64) return false; - const APInt CombinedValueIntVal = C1APIntVal + C2APIntVal; - if (CombinedValueIntVal.getBitWidth() > 64) - return false; - const int64_t CombinedValue = CombinedValueIntVal.getSExtValue(); - - for (SDNode *Node : N0->uses()) { - auto LoadStore = dyn_cast<MemSDNode>(Node); - if (LoadStore) { - // Is x[offset2] already not a legal addressing mode? If so then - // reassociating the constants breaks nothing (we test offset2 because - // that's the one we hope to fold into the load or store). + if (auto *C1 = dyn_cast<ConstantSDNode>(N0.getOperand(1))) { + if (N0.hasOneUse()) + return false; + + const APInt &C1APIntVal = C1->getAPIntValue(); + const APInt CombinedValueIntVal = C1APIntVal + C2APIntVal; + if (CombinedValueIntVal.getSignificantBits() > 64) + return false; + const int64_t CombinedValue = CombinedValueIntVal.getSExtValue(); + + for (SDNode *Node : N->uses()) { + if (auto *LoadStore = dyn_cast<MemSDNode>(Node)) { + // Is x[offset2] already not a legal addressing mode? If so then + // reassociating the constants breaks nothing (we test offset2 because + // that's the one we hope to fold into the load or store). + TargetLoweringBase::AddrMode AM; + AM.HasBaseReg = true; + AM.BaseOffs = C2APIntVal.getSExtValue(); + EVT VT = LoadStore->getMemoryVT(); + unsigned AS = LoadStore->getAddressSpace(); + Type *AccessTy = VT.getTypeForEVT(*DAG.getContext()); + if (!TLI.isLegalAddressingMode(DAG.getDataLayout(), AM, AccessTy, AS)) + continue; + + // Would x[offset1+offset2] still be a legal addressing mode? + AM.BaseOffs = CombinedValue; + if (!TLI.isLegalAddressingMode(DAG.getDataLayout(), AM, AccessTy, AS)) + return true; + } + } + } else { + if (auto *GA = dyn_cast<GlobalAddressSDNode>(N0.getOperand(1))) + if (GA->getOpcode() == ISD::GlobalAddress && TLI.isOffsetFoldingLegal(GA)) + return false; + + for (SDNode *Node : N->uses()) { + auto *LoadStore = dyn_cast<MemSDNode>(Node); + if (!LoadStore) + return false; + + // Is x[offset2] a legal addressing mode? If so then + // reassociating the constants breaks address pattern TargetLoweringBase::AddrMode AM; AM.HasBaseReg = true; AM.BaseOffs = C2APIntVal.getSExtValue(); @@ -1038,13 +1071,9 @@ bool DAGCombiner::reassociationCanBreakAddressingModePattern(unsigned Opc, unsigned AS = LoadStore->getAddressSpace(); Type *AccessTy = VT.getTypeForEVT(*DAG.getContext()); if (!TLI.isLegalAddressingMode(DAG.getDataLayout(), AM, AccessTy, AS)) - continue; - - // Would x[offset1+offset2] still be a legal addressing mode? - AM.BaseOffs = CombinedValue; - if (!TLI.isLegalAddressingMode(DAG.getDataLayout(), AM, AccessTy, AS)) - return true; + return false; } + return true; } return false; @@ -1072,11 +1101,51 @@ SDValue DAGCombiner::reassociateOpsCommutative(unsigned Opc, const SDLoc &DL, if (TLI.isReassocProfitable(DAG, N0, N1)) { // Reassociate: (op (op x, c1), y) -> (op (op x, y), c1) // iff (op x, c1) has one use - if (SDValue OpNode = DAG.getNode(Opc, SDLoc(N0), VT, N00, N1)) - return DAG.getNode(Opc, DL, VT, OpNode, N01); - return SDValue(); + SDValue OpNode = DAG.getNode(Opc, SDLoc(N0), VT, N00, N1); + return DAG.getNode(Opc, DL, VT, OpNode, N01); + } + } + + // Check for repeated operand logic simplifications. + if (Opc == ISD::AND || Opc == ISD::OR) { + // (N00 & N01) & N00 --> N00 & N01 + // (N00 & N01) & N01 --> N00 & N01 + // (N00 | N01) | N00 --> N00 | N01 + // (N00 | N01) | N01 --> N00 | N01 + if (N1 == N00 || N1 == N01) + return N0; + } + if (Opc == ISD::XOR) { + // (N00 ^ N01) ^ N00 --> N01 + if (N1 == N00) + return N01; + // (N00 ^ N01) ^ N01 --> N00 + if (N1 == N01) + return N00; + } + + if (TLI.isReassocProfitable(DAG, N0, N1)) { + if (N1 != N01) { + // Reassociate if (op N00, N1) already exist + if (SDNode *NE = DAG.getNodeIfExists(Opc, DAG.getVTList(VT), {N00, N1})) { + // if Op (Op N00, N1), N01 already exist + // we need to stop reassciate to avoid dead loop + if (!DAG.doesNodeExist(Opc, DAG.getVTList(VT), {SDValue(NE, 0), N01})) + return DAG.getNode(Opc, DL, VT, SDValue(NE, 0), N01); + } + } + + if (N1 != N00) { + // Reassociate if (op N01, N1) already exist + if (SDNode *NE = DAG.getNodeIfExists(Opc, DAG.getVTList(VT), {N01, N1})) { + // if Op (Op N01, N1), N00 already exist + // we need to stop reassciate to avoid dead loop + if (!DAG.doesNodeExist(Opc, DAG.getVTList(VT), {SDValue(NE, 0), N00})) + return DAG.getNode(Opc, DL, VT, SDValue(NE, 0), N00); + } } } + return SDValue(); } @@ -1103,7 +1172,7 @@ SDValue DAGCombiner::CombineTo(SDNode *N, const SDValue *To, unsigned NumTo, assert(N->getNumValues() == NumTo && "Broken CombineTo call!"); ++NodesCombined; LLVM_DEBUG(dbgs() << "\nReplacing.1 "; N->dump(&DAG); dbgs() << "\nWith: "; - To[0].getNode()->dump(&DAG); + To[0].dump(&DAG); dbgs() << " and " << NumTo - 1 << " other values\n"); for (unsigned i = 0, e = NumTo; i != e; ++i) assert((!To[i].getNode() || @@ -1115,10 +1184,8 @@ SDValue DAGCombiner::CombineTo(SDNode *N, const SDValue *To, unsigned NumTo, if (AddTo) { // Push the new nodes and any users onto the worklist for (unsigned i = 0, e = NumTo; i != e; ++i) { - if (To[i].getNode()) { - AddToWorklist(To[i].getNode()); - AddUsersToWorklist(To[i].getNode()); - } + if (To[i].getNode()) + AddToWorklistWithUsers(To[i].getNode()); } } @@ -1134,9 +1201,8 @@ void DAGCombiner:: CommitTargetLoweringOpt(const TargetLowering::TargetLoweringOpt &TLO) { // Replace the old value with the new one. ++NodesCombined; - LLVM_DEBUG(dbgs() << "\nReplacing.2 "; TLO.Old.getNode()->dump(&DAG); - dbgs() << "\nWith: "; TLO.New.getNode()->dump(&DAG); - dbgs() << '\n'); + LLVM_DEBUG(dbgs() << "\nReplacing.2 "; TLO.Old.dump(&DAG); + dbgs() << "\nWith: "; TLO.New.dump(&DAG); dbgs() << '\n'); // Replace all uses. If any nodes become isomorphic to other nodes and // are deleted, make sure to remove them from our worklist. @@ -1149,7 +1215,7 @@ CommitTargetLoweringOpt(const TargetLowering::TargetLoweringOpt &TLO) { // Finally, if the node is now dead, remove it from the graph. The node // may not be dead if the replacement process recursively simplified to // something else needing this node. - if (TLO.Old.getNode()->use_empty()) + if (TLO.Old->use_empty()) deleteAndRecombine(TLO.Old.getNode()); } @@ -1196,7 +1262,7 @@ void DAGCombiner::ReplaceLoadWithPromotedLoad(SDNode *Load, SDNode *ExtLoad) { SDValue Trunc = DAG.getNode(ISD::TRUNCATE, DL, VT, SDValue(ExtLoad, 0)); LLVM_DEBUG(dbgs() << "\nReplacing.9 "; Load->dump(&DAG); dbgs() << "\nWith: "; - Trunc.getNode()->dump(&DAG); dbgs() << '\n'); + Trunc.dump(&DAG); dbgs() << '\n'); WorklistRemover DeadNodes(*this); DAG.ReplaceAllUsesOfValueWith(SDValue(Load, 0), Trunc); DAG.ReplaceAllUsesOfValueWith(SDValue(Load, 1), SDValue(ExtLoad, 1)); @@ -1295,7 +1361,7 @@ SDValue DAGCombiner::PromoteIntBinOp(SDValue Op) { if (TLI.IsDesirableToPromoteOp(Op, PVT)) { assert(PVT != VT && "Don't know what type to promote to!"); - LLVM_DEBUG(dbgs() << "\nPromoting "; Op.getNode()->dump(&DAG)); + LLVM_DEBUG(dbgs() << "\nPromoting "; Op.dump(&DAG)); bool Replace0 = false; SDValue N0 = Op.getOperand(0); @@ -1322,7 +1388,7 @@ SDValue DAGCombiner::PromoteIntBinOp(SDValue Op) { // If operands have a use ordering, make sure we deal with // predecessor first. - if (Replace0 && Replace1 && N0.getNode()->isPredecessorOf(N1.getNode())) { + if (Replace0 && Replace1 && N0->isPredecessorOf(N1.getNode())) { std::swap(N0, N1); std::swap(NN0, NN1); } @@ -1363,11 +1429,10 @@ SDValue DAGCombiner::PromoteIntShiftOp(SDValue Op) { if (TLI.IsDesirableToPromoteOp(Op, PVT)) { assert(PVT != VT && "Don't know what type to promote to!"); - LLVM_DEBUG(dbgs() << "\nPromoting "; Op.getNode()->dump(&DAG)); + LLVM_DEBUG(dbgs() << "\nPromoting "; Op.dump(&DAG)); bool Replace = false; SDValue N0 = Op.getOperand(0); - SDValue N1 = Op.getOperand(1); if (Opc == ISD::SRA) N0 = SExtPromoteOperand(N0, PVT); else if (Opc == ISD::SRL) @@ -1379,6 +1444,7 @@ SDValue DAGCombiner::PromoteIntShiftOp(SDValue Op) { return SDValue(); SDLoc DL(Op); + SDValue N1 = Op.getOperand(1); SDValue RV = DAG.getNode(ISD::TRUNCATE, DL, VT, DAG.getNode(Opc, DL, PVT, N0, N1)); @@ -1414,7 +1480,7 @@ SDValue DAGCombiner::PromoteExtend(SDValue Op) { // fold (aext (aext x)) -> (aext x) // fold (aext (zext x)) -> (zext x) // fold (aext (sext x)) -> (sext x) - LLVM_DEBUG(dbgs() << "\nPromoting "; Op.getNode()->dump(&DAG)); + LLVM_DEBUG(dbgs() << "\nPromoting "; Op.dump(&DAG)); return DAG.getNode(Op.getOpcode(), SDLoc(Op), VT, Op.getOperand(0)); } return SDValue(); @@ -1455,7 +1521,7 @@ bool DAGCombiner::PromoteLoad(SDValue Op) { SDValue Result = DAG.getNode(ISD::TRUNCATE, DL, VT, NewLD); LLVM_DEBUG(dbgs() << "\nPromoting "; N->dump(&DAG); dbgs() << "\nTo: "; - Result.getNode()->dump(&DAG); dbgs() << '\n'); + Result.dump(&DAG); dbgs() << '\n'); WorklistRemover DeadNodes(*this); DAG.ReplaceAllUsesOfValueWith(SDValue(N, 0), Result); DAG.ReplaceAllUsesOfValueWith(SDValue(N, 1), NewLD.getValue(1)); @@ -1569,9 +1635,9 @@ void DAGCombiner::Run(CombineLevel AtLevel) { RV.getOpcode() != ISD::DELETED_NODE && "Node was deleted but visit returned new node!"); - LLVM_DEBUG(dbgs() << " ... into: "; RV.getNode()->dump(&DAG)); + LLVM_DEBUG(dbgs() << " ... into: "; RV.dump(&DAG)); - if (N->getNumValues() == RV.getNode()->getNumValues()) + if (N->getNumValues() == RV->getNumValues()) DAG.ReplaceAllUsesWith(N, RV.getNode()); else { assert(N->getValueType(0) == RV.getValueType() && @@ -1635,6 +1701,10 @@ SDValue DAGCombiner::visit(SDNode *N) { case ISD::UREM: return visitREM(N); case ISD::MULHU: return visitMULHU(N); case ISD::MULHS: return visitMULHS(N); + case ISD::AVGFLOORS: + case ISD::AVGFLOORU: + case ISD::AVGCEILS: + case ISD::AVGCEILU: return visitAVG(N); case ISD::SMUL_LOHI: return visitSMUL_LOHI(N); case ISD::UMUL_LOHI: return visitUMUL_LOHI(N); case ISD::SMULO: @@ -1724,6 +1794,7 @@ SDValue DAGCombiner::visit(SDNode *N) { case ISD::LIFETIME_END: return visitLIFETIME_END(N); case ISD::FP_TO_FP16: return visitFP_TO_FP16(N); case ISD::FP16_TO_FP: return visitFP16_TO_FP(N); + case ISD::FP_TO_BF16: return visitFP_TO_BF16(N); case ISD::FREEZE: return visitFREEZE(N); case ISD::VECREDUCE_FADD: case ISD::VECREDUCE_FMUL: @@ -2072,8 +2143,9 @@ static bool canFoldInAddressingMode(SDNode *N, SDNode *Use, SelectionDAG &DAG, return false; VT = ST->getMemoryVT(); AS = ST->getAddressSpace(); - } else + } else { return false; + } TargetLowering::AddrMode AM; if (N->getOpcode() == ISD::ADD) { @@ -2094,8 +2166,9 @@ static bool canFoldInAddressingMode(SDNode *N, SDNode *Use, SelectionDAG &DAG, else // [reg +/- reg] AM.Scale = 1; - } else + } else { return false; + } return TLI.isLegalAddressingMode(DAG.getDataLayout(), AM, VT.getTypeForEVT(*DAG.getContext()), AS); @@ -2139,6 +2212,18 @@ static SDValue foldSelectWithIdentityConstant(SDNode *N, SelectionDAG &DAG, return C->isExactlyValue(1.0); } } + if (ConstantSDNode *C = isConstOrConstSplat(V)) { + switch (Opcode) { + case ISD::ADD: // X + 0 --> X + case ISD::SUB: // X - 0 --> X + case ISD::SHL: // X << 0 --> X + case ISD::SRA: // X s>> 0 --> X + case ISD::SRL: // X u>> 0 --> X + return C->isZero(); + case ISD::MUL: // X * 1 --> X + return C->isOne(); + } + } return false; }; @@ -2316,6 +2401,15 @@ static SDValue foldAddSubOfSignBit(SDNode *N, SelectionDAG &DAG) { return SDValue(); } +static bool isADDLike(SDValue V, const SelectionDAG &DAG) { + unsigned Opcode = V.getOpcode(); + if (Opcode == ISD::OR) + return DAG.haveNoCommonBitsSet(V.getOperand(0), V.getOperand(1)); + if (Opcode == ISD::XOR) + return isMinSignedConstant(V.getOperand(1)); + return false; +} + /// Try to fold a node that behaves like an ADD (note that N isn't necessarily /// an ISD::ADD here, it could for example be an ISD::OR if we know that there /// are no common bits set in the operands). @@ -2354,66 +2448,60 @@ SDValue DAGCombiner::visitADDLike(SDNode *N) { if (isNullConstant(N1)) return N0; - if (isConstantOrConstantVector(N1, /* NoOpaque */ true)) { + if (N0.getOpcode() == ISD::SUB) { + SDValue N00 = N0.getOperand(0); + SDValue N01 = N0.getOperand(1); + // fold ((A-c1)+c2) -> (A+(c2-c1)) - if (N0.getOpcode() == ISD::SUB && - isConstantOrConstantVector(N0.getOperand(1), /* NoOpaque */ true)) { - SDValue Sub = - DAG.FoldConstantArithmetic(ISD::SUB, DL, VT, {N1, N0.getOperand(1)}); - assert(Sub && "Constant folding failed"); + if (SDValue Sub = DAG.FoldConstantArithmetic(ISD::SUB, DL, VT, {N1, N01})) return DAG.getNode(ISD::ADD, DL, VT, N0.getOperand(0), Sub); - } // fold ((c1-A)+c2) -> (c1+c2)-A - if (N0.getOpcode() == ISD::SUB && - isConstantOrConstantVector(N0.getOperand(0), /* NoOpaque */ true)) { - SDValue Add = - DAG.FoldConstantArithmetic(ISD::ADD, DL, VT, {N1, N0.getOperand(0)}); - assert(Add && "Constant folding failed"); + if (SDValue Add = DAG.FoldConstantArithmetic(ISD::ADD, DL, VT, {N1, N00})) return DAG.getNode(ISD::SUB, DL, VT, Add, N0.getOperand(1)); - } + } - // add (sext i1 X), 1 -> zext (not i1 X) - // We don't transform this pattern: - // add (zext i1 X), -1 -> sext (not i1 X) - // because most (?) targets generate better code for the zext form. - if (N0.getOpcode() == ISD::SIGN_EXTEND && N0.hasOneUse() && - isOneOrOneSplat(N1)) { - SDValue X = N0.getOperand(0); - if ((!LegalOperations || - (TLI.isOperationLegal(ISD::XOR, X.getValueType()) && - TLI.isOperationLegal(ISD::ZERO_EXTEND, VT))) && - X.getScalarValueSizeInBits() == 1) { - SDValue Not = DAG.getNOT(DL, X, X.getValueType()); - return DAG.getNode(ISD::ZERO_EXTEND, DL, VT, Not); - } + // add (sext i1 X), 1 -> zext (not i1 X) + // We don't transform this pattern: + // add (zext i1 X), -1 -> sext (not i1 X) + // because most (?) targets generate better code for the zext form. + if (N0.getOpcode() == ISD::SIGN_EXTEND && N0.hasOneUse() && + isOneOrOneSplat(N1)) { + SDValue X = N0.getOperand(0); + if ((!LegalOperations || + (TLI.isOperationLegal(ISD::XOR, X.getValueType()) && + TLI.isOperationLegal(ISD::ZERO_EXTEND, VT))) && + X.getScalarValueSizeInBits() == 1) { + SDValue Not = DAG.getNOT(DL, X, X.getValueType()); + return DAG.getNode(ISD::ZERO_EXTEND, DL, VT, Not); } + } - // Fold (add (or x, c0), c1) -> (add x, (c0 + c1)) if (or x, c0) is - // equivalent to (add x, c0). - if (N0.getOpcode() == ISD::OR && - isConstantOrConstantVector(N0.getOperand(1), /* NoOpaque */ true) && - DAG.haveNoCommonBitsSet(N0.getOperand(0), N0.getOperand(1))) { - if (SDValue Add0 = DAG.FoldConstantArithmetic(ISD::ADD, DL, VT, - {N1, N0.getOperand(1)})) - return DAG.getNode(ISD::ADD, DL, VT, N0.getOperand(0), Add0); - } + // Fold (add (or x, c0), c1) -> (add x, (c0 + c1)) + // iff (or x, c0) is equivalent to (add x, c0). + // Fold (add (xor x, c0), c1) -> (add x, (c0 + c1)) + // iff (xor x, c0) is equivalent to (add x, c0). + if (isADDLike(N0, DAG)) { + SDValue N01 = N0.getOperand(1); + if (SDValue Add = DAG.FoldConstantArithmetic(ISD::ADD, DL, VT, {N1, N01})) + return DAG.getNode(ISD::ADD, DL, VT, N0.getOperand(0), Add); } if (SDValue NewSel = foldBinOpIntoSelect(N)) return NewSel; // reassociate add - if (!reassociationCanBreakAddressingModePattern(ISD::ADD, DL, N0, N1)) { + if (!reassociationCanBreakAddressingModePattern(ISD::ADD, DL, N, N0, N1)) { if (SDValue RADD = reassociateOps(ISD::ADD, DL, N0, N1, N->getFlags())) return RADD; // Reassociate (add (or x, c), y) -> (add add(x, y), c)) if (or x, c) is // equivalent to (add x, c). + // Reassociate (add (xor x, c), y) -> (add add(x, y), c)) if (xor x, c) is + // equivalent to (add x, c). auto ReassociateAddOr = [&](SDValue N0, SDValue N1) { - if (N0.getOpcode() == ISD::OR && N0.hasOneUse() && - isConstantOrConstantVector(N0.getOperand(1), /* NoOpaque */ true) && - DAG.haveNoCommonBitsSet(N0.getOperand(0), N0.getOperand(1))) { + if (isADDLike(N0, DAG) && N0.hasOneUse() && + isConstantOrConstantVector(N0.getOperand(1), /* NoOpaque */ true)) { return DAG.getNode(ISD::ADD, DL, VT, DAG.getNode(ISD::ADD, DL, VT, N1, N0.getOperand(0)), N0.getOperand(1)); @@ -2473,7 +2561,8 @@ SDValue DAGCombiner::visitADDLike(SDNode *N) { N1.getOperand(1)); // fold (A-B)+(C-D) to (A+C)-(B+D) when A or C is constant - if (N0.getOpcode() == ISD::SUB && N1.getOpcode() == ISD::SUB) { + if (N0.getOpcode() == ISD::SUB && N1.getOpcode() == ISD::SUB && + N0->hasOneUse() && N1->hasOneUse()) { SDValue N00 = N0.getOperand(0); SDValue N01 = N0.getOperand(1); SDValue N10 = N1.getOperand(0); @@ -2526,8 +2615,8 @@ SDValue DAGCombiner::visitADDLike(SDNode *N) { // add (add x, y), 1 // And if the target does not like this form then turn into: // sub y, (xor x, -1) - if (!TLI.preferIncOfAddToSubOfNot(VT) && N0.hasOneUse() && - N0.getOpcode() == ISD::ADD) { + if (!TLI.preferIncOfAddToSubOfNot(VT) && N0.getOpcode() == ISD::ADD && + N0.hasOneUse()) { SDValue Not = DAG.getNode(ISD::XOR, DL, VT, N0.getOperand(0), DAG.getAllOnesConstant(DL, VT)); return DAG.getNode(ISD::SUB, DL, VT, N0.getOperand(1), Not); @@ -2535,7 +2624,7 @@ SDValue DAGCombiner::visitADDLike(SDNode *N) { } // (x - y) + -1 -> add (xor y, -1), x - if (N0.hasOneUse() && N0.getOpcode() == ISD::SUB && + if (N0.getOpcode() == ISD::SUB && N0.hasOneUse() && isAllOnesOrAllOnesSplat(N1)) { SDValue Xor = DAG.getNode(ISD::XOR, DL, VT, N0.getOperand(1), N1); return DAG.getNode(ISD::ADD, DL, VT, Xor, N0.getOperand(0)); @@ -2632,7 +2721,8 @@ SDValue DAGCombiner::visitADDSAT(SDNode *N) { // fold vector ops if (VT.isVector()) { - // TODO SimplifyVBinOp + if (SDValue FoldedVOp = SimplifyVBinOp(N, DL)) + return FoldedVOp; // fold (add_sat x, 0) -> x, vector edition if (ISD::isConstantSplatVectorAllZeros(N1.getNode())) @@ -2678,7 +2768,7 @@ static SDValue getAsCarry(const TargetLowering &TLI, SDValue V) { V.getOpcode() != ISD::UADDO && V.getOpcode() != ISD::USUBO) return SDValue(); - EVT VT = V.getNode()->getValueType(0); + EVT VT = V->getValueType(0); if (!TLI.isOperationLegalOrCustom(V.getOpcode(), VT)) return SDValue(); @@ -2731,27 +2821,27 @@ SDValue DAGCombiner::visitADDLikeCommutative(SDValue N0, SDValue N1, // add (add x, 1), y // And if the target does not like this form then turn into: // sub y, (xor x, -1) - if (!TLI.preferIncOfAddToSubOfNot(VT) && N0.hasOneUse() && - N0.getOpcode() == ISD::ADD && isOneOrOneSplat(N0.getOperand(1))) { + if (!TLI.preferIncOfAddToSubOfNot(VT) && N0.getOpcode() == ISD::ADD && + N0.hasOneUse() && isOneOrOneSplat(N0.getOperand(1))) { SDValue Not = DAG.getNode(ISD::XOR, DL, VT, N0.getOperand(0), DAG.getAllOnesConstant(DL, VT)); return DAG.getNode(ISD::SUB, DL, VT, N1, Not); } - // Hoist one-use subtraction by non-opaque constant: - // (x - C) + y -> (x + y) - C - // This is necessary because SUB(X,C) -> ADD(X,-C) doesn't work for vectors. - if (N0.hasOneUse() && N0.getOpcode() == ISD::SUB && - isConstantOrConstantVector(N0.getOperand(1), /*NoOpaques=*/true)) { - SDValue Add = DAG.getNode(ISD::ADD, DL, VT, N0.getOperand(0), N1); - return DAG.getNode(ISD::SUB, DL, VT, Add, N0.getOperand(1)); - } - // Hoist one-use subtraction from non-opaque constant: - // (C - x) + y -> (y - x) + C - if (N0.hasOneUse() && N0.getOpcode() == ISD::SUB && - isConstantOrConstantVector(N0.getOperand(0), /*NoOpaques=*/true)) { - SDValue Sub = DAG.getNode(ISD::SUB, DL, VT, N1, N0.getOperand(1)); - return DAG.getNode(ISD::ADD, DL, VT, Sub, N0.getOperand(0)); + if (N0.getOpcode() == ISD::SUB && N0.hasOneUse()) { + // Hoist one-use subtraction by non-opaque constant: + // (x - C) + y -> (x + y) - C + // This is necessary because SUB(X,C) -> ADD(X,-C) doesn't work for vectors. + if (isConstantOrConstantVector(N0.getOperand(1), /*NoOpaques=*/true)) { + SDValue Add = DAG.getNode(ISD::ADD, DL, VT, N0.getOperand(0), N1); + return DAG.getNode(ISD::SUB, DL, VT, Add, N0.getOperand(1)); + } + // Hoist one-use subtraction from non-opaque constant: + // (C - x) + y -> (y - x) + C + if (isConstantOrConstantVector(N0.getOperand(0), /*NoOpaques=*/true)) { + SDValue Sub = DAG.getNode(ISD::SUB, DL, VT, N1, N0.getOperand(1)); + return DAG.getNode(ISD::ADD, DL, VT, Sub, N0.getOperand(0)); + } } // If the target's bool is represented as 0/1, prefer to make this 'sub 0/1' @@ -3127,21 +3217,26 @@ static SDValue combineADDCARRYDiamond(DAGCombiner &Combiner, SelectionDAG &DAG, // Our goal is to identify A, B, and CarryIn and produce ADDCARRY/SUBCARRY with // a single path for carry/borrow out propagation: static SDValue combineCarryDiamond(SelectionDAG &DAG, const TargetLowering &TLI, - SDValue Carry0, SDValue Carry1, SDNode *N) { - if (Carry0.getResNo() != 1 || Carry1.getResNo() != 1) + SDValue N0, SDValue N1, SDNode *N) { + SDValue Carry0 = getAsCarry(TLI, N0); + if (!Carry0) return SDValue(); + SDValue Carry1 = getAsCarry(TLI, N1); + if (!Carry1) + return SDValue(); + unsigned Opcode = Carry0.getOpcode(); if (Opcode != Carry1.getOpcode()) return SDValue(); if (Opcode != ISD::UADDO && Opcode != ISD::USUBO) return SDValue(); - // Canonicalize the add/sub of A and B as Carry0 and the add/sub of the - // carry/borrow in as Carry1. (The top and middle uaddo nodes respectively in - // the above ASCII art.) - if (Carry1.getOperand(0) != Carry0.getValue(0) && - Carry1.getOperand(1) != Carry0.getValue(0)) + // Canonicalize the add/sub of A and B (the top node in the above ASCII art) + // as Carry0 and the add/sub of the carry in as Carry1 (the middle node). + if (Carry1.getNode()->isOperandOf(Carry0.getNode())) std::swap(Carry0, Carry1); + + // Check if nodes are connected in expected way. if (Carry1.getOperand(0) != Carry0.getValue(0) && Carry1.getOperand(1) != Carry0.getValue(0)) return SDValue(); @@ -3321,9 +3416,15 @@ SDValue DAGCombiner::visitSUB(SDNode *N) { EVT VT = N0.getValueType(); SDLoc DL(N); + auto PeekThroughFreeze = [](SDValue N) { + if (N->getOpcode() == ISD::FREEZE && N.hasOneUse()) + return N->getOperand(0); + return N; + }; + // fold (sub x, x) -> 0 // FIXME: Refactor this and xor and other similar operations together. - if (N0 == N1) + if (PeekThroughFreeze(N0) == PeekThroughFreeze(N1)) return tryFoldToZero(DL, TLI, VT, DAG, LegalOperations); // fold (sub c1, c2) -> c3 @@ -3381,7 +3482,7 @@ SDValue DAGCombiner::visitSUB(SDNode *N) { } // Convert 0 - abs(x). - if (N1->getOpcode() == ISD::ABS && + if (N1.getOpcode() == ISD::ABS && N1.hasOneUse() && !TLI.isOperationLegalOrCustom(ISD::ABS, VT)) if (SDValue Result = TLI.expandABS(N1.getNode(), DAG, true)) return Result; @@ -3419,44 +3520,31 @@ SDValue DAGCombiner::visitSUB(SDNode *N) { return N0.getOperand(0); // fold (A+C1)-C2 -> A+(C1-C2) - if (N0.getOpcode() == ISD::ADD && - isConstantOrConstantVector(N1, /* NoOpaques */ true) && - isConstantOrConstantVector(N0.getOperand(1), /* NoOpaques */ true)) { - SDValue NewC = - DAG.FoldConstantArithmetic(ISD::SUB, DL, VT, {N0.getOperand(1), N1}); - assert(NewC && "Constant folding failed"); - return DAG.getNode(ISD::ADD, DL, VT, N0.getOperand(0), NewC); + if (N0.getOpcode() == ISD::ADD) { + SDValue N01 = N0.getOperand(1); + if (SDValue NewC = DAG.FoldConstantArithmetic(ISD::SUB, DL, VT, {N01, N1})) + return DAG.getNode(ISD::ADD, DL, VT, N0.getOperand(0), NewC); } // fold C2-(A+C1) -> (C2-C1)-A if (N1.getOpcode() == ISD::ADD) { SDValue N11 = N1.getOperand(1); - if (isConstantOrConstantVector(N0, /* NoOpaques */ true) && - isConstantOrConstantVector(N11, /* NoOpaques */ true)) { - SDValue NewC = DAG.FoldConstantArithmetic(ISD::SUB, DL, VT, {N0, N11}); - assert(NewC && "Constant folding failed"); + if (SDValue NewC = DAG.FoldConstantArithmetic(ISD::SUB, DL, VT, {N0, N11})) return DAG.getNode(ISD::SUB, DL, VT, NewC, N1.getOperand(0)); - } } // fold (A-C1)-C2 -> A-(C1+C2) - if (N0.getOpcode() == ISD::SUB && - isConstantOrConstantVector(N1, /* NoOpaques */ true) && - isConstantOrConstantVector(N0.getOperand(1), /* NoOpaques */ true)) { - SDValue NewC = - DAG.FoldConstantArithmetic(ISD::ADD, DL, VT, {N0.getOperand(1), N1}); - assert(NewC && "Constant folding failed"); - return DAG.getNode(ISD::SUB, DL, VT, N0.getOperand(0), NewC); + if (N0.getOpcode() == ISD::SUB) { + SDValue N01 = N0.getOperand(1); + if (SDValue NewC = DAG.FoldConstantArithmetic(ISD::ADD, DL, VT, {N01, N1})) + return DAG.getNode(ISD::SUB, DL, VT, N0.getOperand(0), NewC); } // fold (c1-A)-c2 -> (c1-c2)-A - if (N0.getOpcode() == ISD::SUB && - isConstantOrConstantVector(N1, /* NoOpaques */ true) && - isConstantOrConstantVector(N0.getOperand(0), /* NoOpaques */ true)) { - SDValue NewC = - DAG.FoldConstantArithmetic(ISD::SUB, DL, VT, {N0.getOperand(0), N1}); - assert(NewC && "Constant folding failed"); - return DAG.getNode(ISD::SUB, DL, VT, NewC, N0.getOperand(1)); + if (N0.getOpcode() == ISD::SUB) { + SDValue N00 = N0.getOperand(0); + if (SDValue NewC = DAG.FoldConstantArithmetic(ISD::SUB, DL, VT, {N00, N1})) + return DAG.getNode(ISD::SUB, DL, VT, NewC, N0.getOperand(1)); } // fold ((A+(B+or-C))-B) -> A+or-C @@ -3651,6 +3739,15 @@ SDValue DAGCombiner::visitSUB(SDNode *N) { } } + // As with the previous fold, prefer add for more folding potential. + // Subtracting SMIN/0 is the same as adding SMIN/0: + // N0 - (X << BW-1) --> N0 + (X << BW-1) + if (N1.getOpcode() == ISD::SHL) { + ConstantSDNode *ShlC = isConstOrConstSplat(N1.getOperand(1)); + if (ShlC && ShlC->getAPIntValue() == VT.getScalarSizeInBits() - 1) + return DAG.getNode(ISD::ADD, DL, VT, N1, N0); + } + if (TLI.isOperationLegalOrCustom(ISD::ADDCARRY, VT)) { // (sub Carry, X) -> (addcarry (sub 0, X), 0, Carry) if (SDValue Carry = getAsCarry(TLI, N0)) { @@ -3686,7 +3783,8 @@ SDValue DAGCombiner::visitSUBSAT(SDNode *N) { // fold vector ops if (VT.isVector()) { - // TODO SimplifyVBinOp + if (SDValue FoldedVOp = SimplifyVBinOp(N, DL)) + return FoldedVOp; // fold (sub_sat x, 0) -> x, vector edition if (ISD::isConstantSplatVectorAllZeros(N1.getNode())) @@ -3837,19 +3935,20 @@ SDValue DAGCombiner::visitMUL(SDNode *N) { SDValue N0 = N->getOperand(0); SDValue N1 = N->getOperand(1); EVT VT = N0.getValueType(); + SDLoc DL(N); // fold (mul x, undef) -> 0 if (N0.isUndef() || N1.isUndef()) - return DAG.getConstant(0, SDLoc(N), VT); + return DAG.getConstant(0, DL, VT); // fold (mul c1, c2) -> c1*c2 - if (SDValue C = DAG.FoldConstantArithmetic(ISD::MUL, SDLoc(N), VT, {N0, N1})) + if (SDValue C = DAG.FoldConstantArithmetic(ISD::MUL, DL, VT, {N0, N1})) return C; // canonicalize constant to RHS (vector doesn't have to splat) if (DAG.isConstantIntBuildVectorOrConstantInt(N0) && !DAG.isConstantIntBuildVectorOrConstantInt(N1)) - return DAG.getNode(ISD::MUL, SDLoc(N), VT, N1, N0); + return DAG.getNode(ISD::MUL, DL, VT, N1, N0); bool N1IsConst = false; bool N1IsOpaqueConst = false; @@ -3857,7 +3956,7 @@ SDValue DAGCombiner::visitMUL(SDNode *N) { // fold vector ops if (VT.isVector()) { - if (SDValue FoldedVOp = SimplifyVBinOp(N, SDLoc(N))) + if (SDValue FoldedVOp = SimplifyVBinOp(N, DL)) return FoldedVOp; N1IsConst = ISD::isConstantSplatVector(N1.getNode(), ConstValue1); @@ -3884,17 +3983,14 @@ SDValue DAGCombiner::visitMUL(SDNode *N) { return NewSel; // fold (mul x, -1) -> 0-x - if (N1IsConst && ConstValue1.isAllOnes()) { - SDLoc DL(N); + if (N1IsConst && ConstValue1.isAllOnes()) return DAG.getNode(ISD::SUB, DL, VT, DAG.getConstant(0, DL, VT), N0); - } // fold (mul x, (1 << c)) -> x << c if (isConstantOrConstantVector(N1, /*NoOpaques*/ true) && DAG.isKnownToBeAPowerOfTwo(N1) && (!VT.isVector() || Level <= AfterLegalizeVectorOps)) { - SDLoc DL(N); SDValue LogBase2 = BuildLogBase2(N1, DL); EVT ShiftVT = getShiftAmountTy(N0.getValueType()); SDValue Trunc = DAG.getZExtOrTrunc(LogBase2, DL, ShiftVT); @@ -3904,7 +4000,6 @@ SDValue DAGCombiner::visitMUL(SDNode *N) { // fold (mul x, -(1 << c)) -> -(x << c) or (-x) << c if (N1IsConst && !N1IsOpaqueConst && ConstValue1.isNegatedPowerOf2()) { unsigned Log2Val = (-ConstValue1).logBase2(); - SDLoc DL(N); // FIXME: If the input is something that is easily negated (e.g. a // single-use add), we should put the negate there. return DAG.getNode(ISD::SUB, DL, VT, @@ -3949,7 +4044,6 @@ SDValue DAGCombiner::visitMUL(SDNode *N) { ShAmt += TZeros; assert(ShAmt < VT.getScalarSizeInBits() && "multiply-by-constant generated out of bounds shift"); - SDLoc DL(N); SDValue Shl = DAG.getNode(ISD::SHL, DL, VT, N0, DAG.getConstant(ShAmt, DL, VT)); SDValue R = @@ -3964,12 +4058,10 @@ SDValue DAGCombiner::visitMUL(SDNode *N) { } // (mul (shl X, c1), c2) -> (mul X, c2 << c1) - if (N0.getOpcode() == ISD::SHL && - isConstantOrConstantVector(N1, /* NoOpaques */ true) && - isConstantOrConstantVector(N0.getOperand(1), /* NoOpaques */ true)) { - SDValue C3 = DAG.getNode(ISD::SHL, SDLoc(N), VT, N1, N0.getOperand(1)); - if (isConstantOrConstantVector(C3)) - return DAG.getNode(ISD::MUL, SDLoc(N), VT, N0.getOperand(0), C3); + if (N0.getOpcode() == ISD::SHL) { + SDValue N01 = N0.getOperand(1); + if (SDValue C3 = DAG.FoldConstantArithmetic(ISD::SHL, DL, VT, {N1, N01})) + return DAG.getNode(ISD::MUL, DL, VT, N0.getOperand(0), C3); } // Change (mul (shl X, C), Y) -> (shl (mul X, Y), C) when the shift has one @@ -3979,18 +4071,17 @@ SDValue DAGCombiner::visitMUL(SDNode *N) { // Check for both (mul (shl X, C), Y) and (mul Y, (shl X, C)). if (N0.getOpcode() == ISD::SHL && - isConstantOrConstantVector(N0.getOperand(1)) && - N0.getNode()->hasOneUse()) { + isConstantOrConstantVector(N0.getOperand(1)) && N0->hasOneUse()) { Sh = N0; Y = N1; } else if (N1.getOpcode() == ISD::SHL && isConstantOrConstantVector(N1.getOperand(1)) && - N1.getNode()->hasOneUse()) { + N1->hasOneUse()) { Sh = N1; Y = N0; } if (Sh.getNode()) { - SDValue Mul = DAG.getNode(ISD::MUL, SDLoc(N), VT, Sh.getOperand(0), Y); - return DAG.getNode(ISD::SHL, SDLoc(N), VT, Mul, Sh.getOperand(1)); + SDValue Mul = DAG.getNode(ISD::MUL, DL, VT, Sh.getOperand(0), Y); + return DAG.getNode(ISD::SHL, DL, VT, Mul, Sh.getOperand(1)); } } @@ -3999,18 +4090,17 @@ SDValue DAGCombiner::visitMUL(SDNode *N) { N0.getOpcode() == ISD::ADD && DAG.isConstantIntBuildVectorOrConstantInt(N0.getOperand(1)) && isMulAddWithConstProfitable(N, N0, N1)) - return DAG.getNode(ISD::ADD, SDLoc(N), VT, - DAG.getNode(ISD::MUL, SDLoc(N0), VT, - N0.getOperand(0), N1), - DAG.getNode(ISD::MUL, SDLoc(N1), VT, - N0.getOperand(1), N1)); + return DAG.getNode( + ISD::ADD, DL, VT, + DAG.getNode(ISD::MUL, SDLoc(N0), VT, N0.getOperand(0), N1), + DAG.getNode(ISD::MUL, SDLoc(N1), VT, N0.getOperand(1), N1)); // Fold (mul (vscale * C0), C1) to (vscale * (C0 * C1)). if (N0.getOpcode() == ISD::VSCALE) if (ConstantSDNode *NC1 = isConstOrConstSplat(N1)) { const APInt &C0 = N0.getConstantOperandAPInt(0); const APInt &C1 = NC1->getAPIntValue(); - return DAG.getVScale(SDLoc(N), VT, C0 * C1); + return DAG.getVScale(DL, VT, C0 * C1); } // Fold (mul step_vector(C0), C1) to (step_vector(C0 * C1)). @@ -4019,7 +4109,7 @@ SDValue DAGCombiner::visitMUL(SDNode *N) { if (ISD::isConstantSplatVector(N1.getNode(), MulVal)) { const APInt &C0 = N0.getConstantOperandAPInt(0); APInt NewStep = C0 * MulVal; - return DAG.getStepVector(SDLoc(N), VT, NewStep); + return DAG.getStepVector(DL, VT, NewStep); } // Fold ((mul x, 0/undef) -> 0, @@ -4041,7 +4131,6 @@ SDValue DAGCombiner::visitMUL(SDNode *N) { if ((!LegalOperations || TLI.isOperationLegalOrCustom(ISD::AND, VT)) && ISD::matchUnaryPredicate(N1, IsClearMask, /*AllowUndefs*/ true)) { assert(N1.getOpcode() == ISD::BUILD_VECTOR && "Unknown constant vector"); - SDLoc DL(N); EVT LegalSVT = N1.getOperand(0).getValueType(); SDValue Zero = DAG.getConstant(0, DL, LegalSVT); SDValue AllOnes = DAG.getAllOnesConstant(DL, LegalSVT); @@ -4054,7 +4143,7 @@ SDValue DAGCombiner::visitMUL(SDNode *N) { } // reassociate mul - if (SDValue RMUL = reassociateOps(ISD::MUL, SDLoc(N), N0, N1, N->getFlags())) + if (SDValue RMUL = reassociateOps(ISD::MUL, DL, N0, N1, N->getFlags())) return RMUL; return SDValue(); @@ -4117,7 +4206,7 @@ SDValue DAGCombiner::useDivRem(SDNode *Node) { SDValue Op0 = Node->getOperand(0); SDValue Op1 = Node->getOperand(1); SDValue combined; - for (SDNode *User : Op0.getNode()->uses()) { + for (SDNode *User : Op0->uses()) { if (User == Node || User->getOpcode() == ISD::DELETED_NODE || User->use_empty()) continue; @@ -4257,12 +4346,7 @@ SDValue DAGCombiner::visitSDIV(SDNode *N) { return SDValue(); } -SDValue DAGCombiner::visitSDIVLike(SDValue N0, SDValue N1, SDNode *N) { - SDLoc DL(N); - EVT VT = N->getValueType(0); - EVT CCVT = getSetCCResultType(VT); - unsigned BitWidth = VT.getScalarSizeInBits(); - +static bool isDivisorPowerOfTwo(SDValue Divisor) { // Helper for determining whether a value is a power-2 constant scalar or a // vector of such elements. auto IsPowerOfTwo = [](ConstantSDNode *C) { @@ -4275,11 +4359,20 @@ SDValue DAGCombiner::visitSDIVLike(SDValue N0, SDValue N1, SDNode *N) { return false; }; + return ISD::matchUnaryPredicate(Divisor, IsPowerOfTwo); +} + +SDValue DAGCombiner::visitSDIVLike(SDValue N0, SDValue N1, SDNode *N) { + SDLoc DL(N); + EVT VT = N->getValueType(0); + EVT CCVT = getSetCCResultType(VT); + unsigned BitWidth = VT.getScalarSizeInBits(); + // fold (sdiv X, pow2) -> simple ops after legalize // FIXME: We check for the exact bit here because the generic lowering gives // better results in that case. The target-specific lowering should learn how // to handle exact sdivs efficiently. - if (!N->getFlags().hasExact() && ISD::matchUnaryPredicate(N1, IsPowerOfTwo)) { + if (!N->getFlags().hasExact() && isDivisorPowerOfTwo(N1)) { // Target-specific implementation of sdiv x, pow2. if (SDValue Res = BuildSDIVPow2(N)) return Res; @@ -4435,6 +4528,16 @@ SDValue DAGCombiner::visitUDIVLike(SDValue N0, SDValue N1, SDNode *N) { return SDValue(); } +SDValue DAGCombiner::buildOptimizedSREM(SDValue N0, SDValue N1, SDNode *N) { + if (!N->getFlags().hasExact() && isDivisorPowerOfTwo(N1) && + !DAG.doesNodeExist(ISD::SDIV, N->getVTList(), {N0, N1})) { + // Target-specific implementation of srem x, pow2. + if (SDValue Res = BuildSREMPow2(N)) + return Res; + } + return SDValue(); +} + // handles ISD::SREM and ISD::UREM SDValue DAGCombiner::visitREM(SDNode *N) { unsigned Opcode = N->getOpcode(); @@ -4451,10 +4554,13 @@ SDValue DAGCombiner::visitREM(SDNode *N) { if (SDValue C = DAG.FoldConstantArithmetic(Opcode, DL, VT, {N0, N1})) return C; - // fold (urem X, -1) -> select(X == -1, 0, x) - if (!isSigned && N1C && N1C->isAllOnes()) - return DAG.getSelect(DL, VT, DAG.getSetCC(DL, CCVT, N0, N1, ISD::SETEQ), - DAG.getConstant(0, DL, VT), N0); + // fold (urem X, -1) -> select(FX == -1, 0, FX) + // Freeze the numerator to avoid a miscompile with an undefined value. + if (!isSigned && N1C && N1C->isAllOnes()) { + SDValue F0 = DAG.getFreeze(N0); + SDValue EqualsNeg1 = DAG.getSetCC(DL, CCVT, F0, N1, ISD::SETEQ); + return DAG.getSelect(DL, VT, EqualsNeg1, DAG.getConstant(0, DL, VT), F0); + } if (SDValue V = simplifyDivRem(N, DAG)) return V; @@ -4495,6 +4601,12 @@ SDValue DAGCombiner::visitREM(SDNode *N) { // combine will not return a DIVREM. Regardless, checking cheapness here // makes sense since the simplification results in fatter code. if (DAG.isKnownNeverZero(N1) && !TLI.isIntDivCheap(VT, Attr)) { + if (isSigned) { + // check if we can build faster implementation for srem + if (SDValue OptimizedRem = buildOptimizedSREM(N0, N1, N)) + return OptimizedRem; + } + SDValue OptimizedDiv = isSigned ? visitSDIVLike(N0, N1, N) : visitUDIVLike(N0, N1, N); if (OptimizedDiv.getNode() && OptimizedDiv.getNode() != N) { @@ -4654,6 +4766,46 @@ SDValue DAGCombiner::visitMULHU(SDNode *N) { return SDValue(); } +SDValue DAGCombiner::visitAVG(SDNode *N) { + unsigned Opcode = N->getOpcode(); + SDValue N0 = N->getOperand(0); + SDValue N1 = N->getOperand(1); + EVT VT = N->getValueType(0); + SDLoc DL(N); + + // fold (avg c1, c2) + if (SDValue C = DAG.FoldConstantArithmetic(Opcode, DL, VT, {N0, N1})) + return C; + + // canonicalize constant to RHS. + if (DAG.isConstantIntBuildVectorOrConstantInt(N0) && + !DAG.isConstantIntBuildVectorOrConstantInt(N1)) + return DAG.getNode(Opcode, DL, N->getVTList(), N1, N0); + + if (VT.isVector()) { + if (SDValue FoldedVOp = SimplifyVBinOp(N, DL)) + return FoldedVOp; + + // fold (avgfloor x, 0) -> x >> 1 + if (ISD::isConstantSplatVectorAllZeros(N1.getNode())) { + if (Opcode == ISD::AVGFLOORS) + return DAG.getNode(ISD::SRA, DL, VT, N0, DAG.getConstant(1, DL, VT)); + if (Opcode == ISD::AVGFLOORU) + return DAG.getNode(ISD::SRL, DL, VT, N0, DAG.getConstant(1, DL, VT)); + } + } + + // fold (avg x, undef) -> x + if (N0.isUndef()) + return N1; + if (N1.isUndef()) + return N0; + + // TODO If we use avg for scalars anywhere, we can add (avgfl x, 0) -> x >> 1 + + return SDValue(); +} + /// Perform optimizations common to nodes that compute two values. LoOp and HiOp /// give the opcodes for the two computations that are being performed. Return /// true if a simplification was made. @@ -4812,7 +4964,9 @@ SDValue DAGCombiner::visitMULO(SDNode *N) { DAG.getConstant(0, DL, CarryVT)); // (mulo x, 2) -> (addo x, x) - if (N1C && N1C->getAPIntValue() == 2) + // FIXME: This needs a freeze. + if (N1C && N1C->getAPIntValue() == 2 && + (!IsSigned || VT.getScalarSizeInBits() > 2)) return DAG.getNode(IsSigned ? ISD::SADDO : ISD::UADDO, DL, N->getVTList(), N0, N0); @@ -4869,8 +5023,7 @@ static SDValue isSaturatingMinMax(SDValue N0, SDValue N1, SDValue N2, return 0; const APInt &C1 = N1C->getAPIntValue(); const APInt &C2 = N3C->getAPIntValue(); - if (C1.getBitWidth() < C2.getBitWidth() || - C1 != C2.sextOrSelf(C1.getBitWidth())) + if (C1.getBitWidth() < C2.getBitWidth() || C1 != C2.sext(C1.getBitWidth())) return 0; return CC == ISD::SETLT ? ISD::SMIN : (CC == ISD::SETGT ? ISD::SMAX : 0); }; @@ -4977,7 +5130,7 @@ static SDValue PerformUMinFpToSatCombine(SDValue N0, SDValue N1, SDValue N2, const APInt &C1 = N1C->getAPIntValue(); const APInt &C3 = N3C->getAPIntValue(); if (!(C1 + 1).isPowerOf2() || C1.getBitWidth() < C3.getBitWidth() || - C1 != C3.zextOrSelf(C1.getBitWidth())) + C1 != C3.zext(C1.getBitWidth())) return SDValue(); unsigned BW = (C1 + 1).exactLogBase2(); @@ -5007,6 +5160,10 @@ SDValue DAGCombiner::visitIMINMAX(SDNode *N) { if (SDValue C = DAG.FoldConstantArithmetic(Opcode, DL, VT, {N0, N1})) return C; + // If the operands are the same, this is a no-op. + if (N0 == N1) + return N0; + // canonicalize constant to RHS if (DAG.isConstantIntBuildVectorOrConstantInt(N0) && !DAG.isConstantIntBuildVectorOrConstantInt(N1)) @@ -5312,29 +5469,27 @@ SDValue DAGCombiner::foldLogicOfSetCCs(bool IsAnd, SDValue N0, SDValue N1, } // Turn compare of constants whose difference is 1 bit into add+and+setcc. - // TODO - support non-uniform vector amounts. if ((IsAnd && CC1 == ISD::SETNE) || (!IsAnd && CC1 == ISD::SETEQ)) { // Match a shared variable operand and 2 non-opaque constant operands. - ConstantSDNode *C0 = isConstOrConstSplat(LR); - ConstantSDNode *C1 = isConstOrConstSplat(RR); - if (LL == RL && C0 && C1 && !C0->isOpaque() && !C1->isOpaque()) { + auto MatchDiffPow2 = [&](ConstantSDNode *C0, ConstantSDNode *C1) { + // The difference of the constants must be a single bit. const APInt &CMax = APIntOps::umax(C0->getAPIntValue(), C1->getAPIntValue()); const APInt &CMin = APIntOps::umin(C0->getAPIntValue(), C1->getAPIntValue()); - // The difference of the constants must be a single bit. - if ((CMax - CMin).isPowerOf2()) { - // and/or (setcc X, CMax, ne), (setcc X, CMin, ne/eq) --> - // setcc ((sub X, CMin), ~(CMax - CMin)), 0, ne/eq - SDValue Max = DAG.getNode(ISD::UMAX, DL, OpVT, LR, RR); - SDValue Min = DAG.getNode(ISD::UMIN, DL, OpVT, LR, RR); - SDValue Offset = DAG.getNode(ISD::SUB, DL, OpVT, LL, Min); - SDValue Diff = DAG.getNode(ISD::SUB, DL, OpVT, Max, Min); - SDValue Mask = DAG.getNOT(DL, Diff, OpVT); - SDValue And = DAG.getNode(ISD::AND, DL, OpVT, Offset, Mask); - SDValue Zero = DAG.getConstant(0, DL, OpVT); - return DAG.getSetCC(DL, VT, And, Zero, CC0); - } + return !C0->isOpaque() && !C1->isOpaque() && (CMax - CMin).isPowerOf2(); + }; + if (LL == RL && ISD::matchBinaryPredicate(LR, RR, MatchDiffPow2)) { + // and/or (setcc X, CMax, ne), (setcc X, CMin, ne/eq) --> + // setcc ((sub X, CMin), ~(CMax - CMin)), 0, ne/eq + SDValue Max = DAG.getNode(ISD::UMAX, DL, OpVT, LR, RR); + SDValue Min = DAG.getNode(ISD::UMIN, DL, OpVT, LR, RR); + SDValue Offset = DAG.getNode(ISD::SUB, DL, OpVT, LL, Min); + SDValue Diff = DAG.getNode(ISD::SUB, DL, OpVT, Max, Min); + SDValue Mask = DAG.getNOT(DL, Diff, OpVT); + SDValue And = DAG.getNode(ISD::AND, DL, OpVT, Offset, Mask); + SDValue Zero = DAG.getConstant(0, DL, OpVT); + return DAG.getSetCC(DL, VT, And, Zero, CC0); } } } @@ -5836,6 +5991,9 @@ static SDValue combineShiftAnd1ToBitTest(SDNode *And, SelectionDAG &DAG) { if (ShiftAmt.uge(VTBitWidth)) return SDValue(); + if (!TLI.hasBitTest(Srl.getOperand(0), Srl.getOperand(1))) + return SDValue(); + // Turn this into a bit-test pattern using mask op + setcc: // and (not (srl X, C)), 1 --> (and X, 1<<C) == 0 SDLoc DL(And); @@ -5882,6 +6040,53 @@ static SDValue foldAndToUsubsat(SDNode *N, SelectionDAG &DAG) { return DAG.getNode(ISD::USUBSAT, DL, VT, N0.getOperand(0), SignMask); } +/// Given a bitwise logic operation N with a matching bitwise logic operand, +/// fold a pattern where 2 of the source operands are identically shifted +/// values. For example: +/// ((X0 << Y) | Z) | (X1 << Y) --> ((X0 | X1) << Y) | Z +static SDValue foldLogicOfShifts(SDNode *N, SDValue LogicOp, SDValue ShiftOp, + SelectionDAG &DAG) { + unsigned LogicOpcode = N->getOpcode(); + assert((LogicOpcode == ISD::AND || LogicOpcode == ISD::OR || + LogicOpcode == ISD::XOR) + && "Expected bitwise logic operation"); + + if (!LogicOp.hasOneUse() || !ShiftOp.hasOneUse()) + return SDValue(); + + // Match another bitwise logic op and a shift. + unsigned ShiftOpcode = ShiftOp.getOpcode(); + if (LogicOp.getOpcode() != LogicOpcode || + !(ShiftOpcode == ISD::SHL || ShiftOpcode == ISD::SRL || + ShiftOpcode == ISD::SRA)) + return SDValue(); + + // Match another shift op inside the first logic operand. Handle both commuted + // possibilities. + // LOGIC (LOGIC (SH X0, Y), Z), (SH X1, Y) --> LOGIC (SH (LOGIC X0, X1), Y), Z + // LOGIC (LOGIC Z, (SH X0, Y)), (SH X1, Y) --> LOGIC (SH (LOGIC X0, X1), Y), Z + SDValue X1 = ShiftOp.getOperand(0); + SDValue Y = ShiftOp.getOperand(1); + SDValue X0, Z; + if (LogicOp.getOperand(0).getOpcode() == ShiftOpcode && + LogicOp.getOperand(0).getOperand(1) == Y) { + X0 = LogicOp.getOperand(0).getOperand(0); + Z = LogicOp.getOperand(1); + } else if (LogicOp.getOperand(1).getOpcode() == ShiftOpcode && + LogicOp.getOperand(1).getOperand(1) == Y) { + X0 = LogicOp.getOperand(1).getOperand(0); + Z = LogicOp.getOperand(0); + } else { + return SDValue(); + } + + EVT VT = N->getValueType(0); + SDLoc DL(N); + SDValue LogicX = DAG.getNode(LogicOpcode, DL, VT, X0, X1); + SDValue NewShift = DAG.getNode(ShiftOpcode, DL, VT, LogicX, Y); + return DAG.getNode(LogicOpcode, DL, VT, NewShift, Z); +} + SDValue DAGCombiner::visitAND(SDNode *N) { SDValue N0 = N->getOperand(0); SDValue N1 = N->getOperand(1); @@ -5915,27 +6120,25 @@ SDValue DAGCombiner::visitAND(SDNode *N) { if (ISD::isConstantSplatVectorAllOnes(N1.getNode())) return N0; - // fold (and (masked_load) (build_vec (x, ...))) to zext_masked_load + // fold (and (masked_load) (splat_vec (x, ...))) to zext_masked_load auto *MLoad = dyn_cast<MaskedLoadSDNode>(N0); - auto *BVec = dyn_cast<BuildVectorSDNode>(N1); - if (MLoad && BVec && MLoad->getExtensionType() == ISD::EXTLOAD && - N0.hasOneUse() && N1.hasOneUse()) { + ConstantSDNode *Splat = isConstOrConstSplat(N1, true, true); + if (MLoad && MLoad->getExtensionType() == ISD::EXTLOAD && N0.hasOneUse() && + Splat && N1.hasOneUse()) { EVT LoadVT = MLoad->getMemoryVT(); EVT ExtVT = VT; if (TLI.isLoadExtLegal(ISD::ZEXTLOAD, ExtVT, LoadVT)) { // For this AND to be a zero extension of the masked load the elements // of the BuildVec must mask the bottom bits of the extended element // type - if (ConstantSDNode *Splat = BVec->getConstantSplatNode()) { - uint64_t ElementSize = - LoadVT.getVectorElementType().getScalarSizeInBits(); - if (Splat->getAPIntValue().isMask(ElementSize)) { - return DAG.getMaskedLoad( - ExtVT, SDLoc(N), MLoad->getChain(), MLoad->getBasePtr(), - MLoad->getOffset(), MLoad->getMask(), MLoad->getPassThru(), - LoadVT, MLoad->getMemOperand(), MLoad->getAddressingMode(), - ISD::ZEXTLOAD, MLoad->isExpandingLoad()); - } + uint64_t ElementSize = + LoadVT.getVectorElementType().getScalarSizeInBits(); + if (Splat->getAPIntValue().isMask(ElementSize)) { + return DAG.getMaskedLoad( + ExtVT, SDLoc(N), MLoad->getChain(), MLoad->getBasePtr(), + MLoad->getOffset(), MLoad->getMask(), MLoad->getPassThru(), + LoadVT, MLoad->getMemOperand(), MLoad->getAddressingMode(), + ISD::ZEXTLOAD, MLoad->isExpandingLoad()); } } } @@ -6011,7 +6214,7 @@ SDValue DAGCombiner::visitAND(SDNode *N) { // This can be a pure constant or a vector splat, in which case we treat the // vector as a scalar and use the splat value. APInt Constant = APInt::getZero(1); - if (const ConstantSDNode *C = dyn_cast<ConstantSDNode>(N1)) { + if (const ConstantSDNode *C = isConstOrConstSplat(N1)) { Constant = C->getAPIntValue(); } else if (BuildVectorSDNode *Vector = dyn_cast<BuildVectorSDNode>(N1)) { APInt SplatValue, SplatUndef; @@ -6151,6 +6354,11 @@ SDValue DAGCombiner::visitAND(SDNode *N) { if (SDValue V = hoistLogicOpWithSameOpcodeHands(N)) return V; + if (SDValue R = foldLogicOfShifts(N, N0, N1, DAG)) + return R; + if (SDValue R = foldLogicOfShifts(N, N1, N0, DAG)) + return R; + // Masking the negated extension of a boolean is just the zero-extended // boolean: // and (sub 0, zext(bool X)), 1 --> zext(bool X) @@ -6209,9 +6417,8 @@ SDValue DAGCombiner::visitAND(SDNode *N) { if (SDValue Shifts = unfoldExtremeBitClearingToShifts(N)) return Shifts; - if (TLI.hasBitTest(N0, N1)) - if (SDValue V = combineShiftAnd1ToBitTest(N, DAG)) - return V; + if (SDValue V = combineShiftAnd1ToBitTest(N, DAG)) + return V; // Recognize the following pattern: // @@ -6261,11 +6468,11 @@ SDValue DAGCombiner::MatchBSwapHWordLow(SDNode *N, SDValue N0, SDValue N1, bool LookPassAnd0 = false; bool LookPassAnd1 = false; if (N0.getOpcode() == ISD::AND && N0.getOperand(0).getOpcode() == ISD::SRL) - std::swap(N0, N1); + std::swap(N0, N1); if (N1.getOpcode() == ISD::AND && N1.getOperand(0).getOpcode() == ISD::SHL) - std::swap(N0, N1); + std::swap(N0, N1); if (N0.getOpcode() == ISD::AND) { - if (!N0.getNode()->hasOneUse()) + if (!N0->hasOneUse()) return SDValue(); ConstantSDNode *N01C = dyn_cast<ConstantSDNode>(N0.getOperand(1)); // Also handle 0xffff since the LHS is guaranteed to have zeros there. @@ -6278,7 +6485,7 @@ SDValue DAGCombiner::MatchBSwapHWordLow(SDNode *N, SDValue N0, SDValue N1, } if (N1.getOpcode() == ISD::AND) { - if (!N1.getNode()->hasOneUse()) + if (!N1->hasOneUse()) return SDValue(); ConstantSDNode *N11C = dyn_cast<ConstantSDNode>(N1.getOperand(1)); if (!N11C || N11C->getZExtValue() != 0xFF) @@ -6291,7 +6498,7 @@ SDValue DAGCombiner::MatchBSwapHWordLow(SDNode *N, SDValue N0, SDValue N1, std::swap(N0, N1); if (N0.getOpcode() != ISD::SHL || N1.getOpcode() != ISD::SRL) return SDValue(); - if (!N0.getNode()->hasOneUse() || !N1.getNode()->hasOneUse()) + if (!N0->hasOneUse() || !N1->hasOneUse()) return SDValue(); ConstantSDNode *N01C = dyn_cast<ConstantSDNode>(N0.getOperand(1)); @@ -6304,7 +6511,7 @@ SDValue DAGCombiner::MatchBSwapHWordLow(SDNode *N, SDValue N0, SDValue N1, // Look for (shl (and a, 0xff), 8), (srl (and a, 0xff00), 8) SDValue N00 = N0->getOperand(0); if (!LookPassAnd0 && N00.getOpcode() == ISD::AND) { - if (!N00.getNode()->hasOneUse()) + if (!N00->hasOneUse()) return SDValue(); ConstantSDNode *N001C = dyn_cast<ConstantSDNode>(N00.getOperand(1)); if (!N001C || N001C->getZExtValue() != 0xFF) @@ -6315,7 +6522,7 @@ SDValue DAGCombiner::MatchBSwapHWordLow(SDNode *N, SDValue N0, SDValue N1, SDValue N10 = N1->getOperand(0); if (!LookPassAnd1 && N10.getOpcode() == ISD::AND) { - if (!N10.getNode()->hasOneUse()) + if (!N10->hasOneUse()) return SDValue(); ConstantSDNode *N101C = dyn_cast<ConstantSDNode>(N10.getOperand(1)); // Also allow 0xFFFF since the bits will be shifted out. This is needed @@ -6333,19 +6540,23 @@ SDValue DAGCombiner::MatchBSwapHWordLow(SDNode *N, SDValue N0, SDValue N1, // Make sure everything beyond the low halfword gets set to zero since the SRL // 16 will clear the top bits. unsigned OpSizeInBits = VT.getSizeInBits(); - if (DemandHighBits && OpSizeInBits > 16) { + if (OpSizeInBits > 16) { // If the left-shift isn't masked out then the only way this is a bswap is // if all bits beyond the low 8 are 0. In that case the entire pattern // reduces to a left shift anyway: leave it for other parts of the combiner. - if (!LookPassAnd0) + if (DemandHighBits && !LookPassAnd0) return SDValue(); // However, if the right shift isn't masked out then it might be because - // it's not needed. See if we can spot that too. - if (!LookPassAnd1 && - !DAG.MaskedValueIsZero( - N10, APInt::getHighBitsSet(OpSizeInBits, OpSizeInBits - 16))) - return SDValue(); + // it's not needed. See if we can spot that too. If the high bits aren't + // demanded, we only need bits 23:16 to be zero. Otherwise, we need all + // upper bits to be zero. + if (!LookPassAnd1) { + unsigned HighBit = DemandHighBits ? OpSizeInBits : 24; + if (!DAG.MaskedValueIsZero(N10, + APInt::getBitsSet(OpSizeInBits, 16, HighBit))) + return SDValue(); + } } SDValue Res = DAG.getNode(ISD::BSWAP, SDLoc(N), VT, N00); @@ -6365,7 +6576,7 @@ SDValue DAGCombiner::MatchBSwapHWordLow(SDNode *N, SDValue N0, SDValue N1, /// ((x & 0x00ff0000) << 8) | /// ((x & 0xff000000) >> 8) static bool isBSwapHWordElement(SDValue N, MutableArrayRef<SDNode *> Parts) { - if (!N.getNode()->hasOneUse()) + if (!N->hasOneUse()) return false; unsigned Opc = N.getOpcode(); @@ -6552,8 +6763,9 @@ SDValue DAGCombiner::MatchBSwapHWord(SDNode *N, SDValue N0, SDValue N1) { if (!(isBSwapHWordElement(N01, Parts) && isBSwapHWordPair(N00, Parts)) && !(isBSwapHWordElement(N00, Parts) && isBSwapHWordPair(N01, Parts))) return SDValue(); - } else + } else { return SDValue(); + } // Make sure the parts are all coming from the same node. if (Parts[0] != Parts[1] || Parts[0] != Parts[2] || Parts[0] != Parts[3]) @@ -6591,7 +6803,7 @@ SDValue DAGCombiner::visitORLike(SDValue N0, SDValue N1, SDNode *N) { // (or (and X, C1), (and Y, C2)) -> (and (or X, Y), C3) if possible. if (N0.getOpcode() == ISD::AND && N1.getOpcode() == ISD::AND && // Don't increase # computations. - (N0.getNode()->hasOneUse() || N1.getNode()->hasOneUse())) { + (N0->hasOneUse() || N1->hasOneUse())) { // We can only do this xform if we know that bits from X that are set in C2 // but not in C1 are already zero. Likewise for Y. if (const ConstantSDNode *N0O1C = @@ -6619,7 +6831,7 @@ SDValue DAGCombiner::visitORLike(SDValue N0, SDValue N1, SDNode *N) { N1.getOpcode() == ISD::AND && N0.getOperand(0) == N1.getOperand(0) && // Don't increase # computations. - (N0.getNode()->hasOneUse() || N1.getNode()->hasOneUse())) { + (N0->hasOneUse() || N1->hasOneUse())) { SDValue X = DAG.getNode(ISD::OR, SDLoc(N0), VT, N0.getOperand(1), N1.getOperand(1)); return DAG.getNode(ISD::AND, DL, VT, N0.getOperand(0), X); @@ -6634,14 +6846,38 @@ static SDValue visitORCommutative( EVT VT = N0.getValueType(); if (N0.getOpcode() == ISD::AND) { // fold (or (and X, (xor Y, -1)), Y) -> (or X, Y) - if (isBitwiseNot(N0.getOperand(1)) && N0.getOperand(1).getOperand(0) == N1) + // TODO: Set AllowUndefs = true. + if (getBitwiseNotOperand(N0.getOperand(1), N0.getOperand(0), + /* AllowUndefs */ false) == N1) return DAG.getNode(ISD::OR, SDLoc(N), VT, N0.getOperand(0), N1); // fold (or (and (xor Y, -1), X), Y) -> (or X, Y) - if (isBitwiseNot(N0.getOperand(0)) && N0.getOperand(0).getOperand(0) == N1) + if (getBitwiseNotOperand(N0.getOperand(0), N0.getOperand(1), + /* AllowUndefs */ false) == N1) return DAG.getNode(ISD::OR, SDLoc(N), VT, N0.getOperand(1), N1); } + if (SDValue R = foldLogicOfShifts(N, N0, N1, DAG)) + return R; + + auto peekThroughZext = [](SDValue V) { + if (V->getOpcode() == ISD::ZERO_EXTEND) + return V->getOperand(0); + return V; + }; + + // (fshl X, ?, Y) | (shl X, Y) --> fshl X, ?, Y + if (N0.getOpcode() == ISD::FSHL && N1.getOpcode() == ISD::SHL && + N0.getOperand(0) == N1.getOperand(0) && + peekThroughZext(N0.getOperand(2)) == peekThroughZext(N1.getOperand(1))) + return N0; + + // (fshr ?, X, Y) | (srl X, Y) --> fshr ?, X, Y + if (N0.getOpcode() == ISD::FSHR && N1.getOpcode() == ISD::SRL && + N0.getOperand(1) == N1.getOperand(0) && + peekThroughZext(N0.getOperand(2)) == peekThroughZext(N1.getOperand(1))) + return N0; + return SDValue(); } @@ -6678,11 +6914,10 @@ SDValue DAGCombiner::visitOR(SDNode *N) { return DAG.getAllOnesConstant(SDLoc(N), N1.getValueType()); // fold (or (shuf A, V_0, MA), (shuf B, V_0, MB)) -> (shuf A, B, Mask) - // Do this only if the resulting shuffle is legal. - if (isa<ShuffleVectorSDNode>(N0) && - isa<ShuffleVectorSDNode>(N1) && - // Avoid folding a node with illegal type. - TLI.isTypeLegal(VT)) { + // Do this only if the resulting type / shuffle is legal. + auto *SV0 = dyn_cast<ShuffleVectorSDNode>(N0); + auto *SV1 = dyn_cast<ShuffleVectorSDNode>(N1); + if (SV0 && SV1 && TLI.isTypeLegal(VT)) { bool ZeroN00 = ISD::isBuildVectorAllZeros(N0.getOperand(0).getNode()); bool ZeroN01 = ISD::isBuildVectorAllZeros(N0.getOperand(1).getNode()); bool ZeroN10 = ISD::isBuildVectorAllZeros(N1.getOperand(0).getNode()); @@ -6691,11 +6926,9 @@ SDValue DAGCombiner::visitOR(SDNode *N) { if ((ZeroN00 != ZeroN01) && (ZeroN10 != ZeroN11)) { assert((!ZeroN00 || !ZeroN01) && "Both inputs zero!"); assert((!ZeroN10 || !ZeroN11) && "Both inputs zero!"); - const ShuffleVectorSDNode *SV0 = cast<ShuffleVectorSDNode>(N0); - const ShuffleVectorSDNode *SV1 = cast<ShuffleVectorSDNode>(N1); bool CanFold = true; int NumElts = VT.getVectorNumElements(); - SmallVector<int, 4> Mask(NumElts); + SmallVector<int, 4> Mask(NumElts, -1); for (int i = 0; i != NumElts; ++i) { int M0 = SV0->getMaskElt(i); @@ -6707,10 +6940,8 @@ SDValue DAGCombiner::visitOR(SDNode *N) { // If one element is zero and the otherside is undef, keep undef. // This also handles the case that both are undef. - if ((M0Zero && M1 < 0) || (M1Zero && M0 < 0)) { - Mask[i] = -1; + if ((M0Zero && M1 < 0) || (M1Zero && M0 < 0)) continue; - } // Make sure only one of the elements is zero. if (M0Zero == M1Zero) { @@ -6778,7 +7009,7 @@ SDValue DAGCombiner::visitOR(SDNode *N) { auto MatchIntersect = [](ConstantSDNode *C1, ConstantSDNode *C2) { return !C1 || !C2 || C1->getAPIntValue().intersects(C2->getAPIntValue()); }; - if (N0.getOpcode() == ISD::AND && N0.getNode()->hasOneUse() && + if (N0.getOpcode() == ISD::AND && N0->hasOneUse() && ISD::matchBinaryPredicate(N0.getOperand(1), N1, MatchIntersect, true)) { if (SDValue COR = DAG.FoldConstantArithmetic(ISD::OR, SDLoc(N1), VT, {N1, N0.getOperand(1)})) { @@ -7098,8 +7329,9 @@ static bool matchRotateSub(SDValue Pos, SDValue Neg, unsigned EltSize, // Neg with outer conversions stripped away. SDValue DAGCombiner::MatchRotatePosNeg(SDValue Shifted, SDValue Pos, SDValue Neg, SDValue InnerPos, - SDValue InnerNeg, unsigned PosOpcode, - unsigned NegOpcode, const SDLoc &DL) { + SDValue InnerNeg, bool HasPos, + unsigned PosOpcode, unsigned NegOpcode, + const SDLoc &DL) { // fold (or (shl x, (*ext y)), // (srl x, (*ext (sub 32, y)))) -> // (rotl x, y) or (rotr x, (sub 32, y)) @@ -7110,7 +7342,6 @@ SDValue DAGCombiner::MatchRotatePosNeg(SDValue Shifted, SDValue Pos, EVT VT = Shifted.getValueType(); if (matchRotateSub(InnerPos, InnerNeg, VT.getScalarSizeInBits(), DAG, /*IsRotate*/ true)) { - bool HasPos = TLI.isOperationLegalOrCustom(PosOpcode, VT); return DAG.getNode(HasPos ? PosOpcode : NegOpcode, DL, VT, Shifted, HasPos ? Pos : Neg); } @@ -7126,8 +7357,9 @@ SDValue DAGCombiner::MatchRotatePosNeg(SDValue Shifted, SDValue Pos, // TODO: Merge with MatchRotatePosNeg. SDValue DAGCombiner::MatchFunnelPosNeg(SDValue N0, SDValue N1, SDValue Pos, SDValue Neg, SDValue InnerPos, - SDValue InnerNeg, unsigned PosOpcode, - unsigned NegOpcode, const SDLoc &DL) { + SDValue InnerNeg, bool HasPos, + unsigned PosOpcode, unsigned NegOpcode, + const SDLoc &DL) { EVT VT = N0.getValueType(); unsigned EltBits = VT.getScalarSizeInBits(); @@ -7139,7 +7371,6 @@ SDValue DAGCombiner::MatchFunnelPosNeg(SDValue N0, SDValue N1, SDValue Pos, // (srl x1, (*ext y))) -> // (fshr x0, x1, y) or (fshl x0, x1, (sub 32, y)) if (matchRotateSub(InnerPos, InnerNeg, EltBits, DAG, /*IsRotate*/ N0 == N1)) { - bool HasPos = TLI.isOperationLegalOrCustom(PosOpcode, VT); return DAG.getNode(HasPos ? PosOpcode : NegOpcode, DL, VT, N0, N1, HasPos ? Pos : Neg); } @@ -7201,6 +7432,16 @@ SDValue DAGCombiner::MatchRotate(SDValue LHS, SDValue RHS, const SDLoc &DL) { bool HasROTR = hasOperation(ISD::ROTR, VT); bool HasFSHL = hasOperation(ISD::FSHL, VT); bool HasFSHR = hasOperation(ISD::FSHR, VT); + + // If the type is going to be promoted and the target has enabled custom + // lowering for rotate, allow matching rotate by non-constants. Only allow + // this for scalar types. + if (VT.isScalarInteger() && TLI.getTypeAction(*DAG.getContext(), VT) == + TargetLowering::TypePromoteInteger) { + HasROTL |= TLI.getOperationAction(ISD::ROTL, VT) == TargetLowering::Custom; + HasROTR |= TLI.getOperationAction(ISD::ROTR, VT) == TargetLowering::Custom; + } + if (LegalOperations && !HasROTL && !HasROTR && !HasFSHL && !HasFSHR) return SDValue(); @@ -7254,11 +7495,6 @@ SDValue DAGCombiner::MatchRotate(SDValue LHS, SDValue RHS, const SDLoc &DL) { if (LHSShift.getOpcode() == RHSShift.getOpcode()) return SDValue(); // Shifts must disagree. - // TODO: Support pre-legalization funnel-shift by constant. - bool IsRotate = LHSShift.getOperand(0) == RHSShift.getOperand(0); - if (!IsRotate && !(HasFSHL || HasFSHR)) - return SDValue(); // Requires funnel shift support. - // Canonicalize shl to left side in a shl/srl pair. if (RHSShift.getOpcode() == ISD::SHL) { std::swap(LHS, RHS); @@ -7272,27 +7508,12 @@ SDValue DAGCombiner::MatchRotate(SDValue LHS, SDValue RHS, const SDLoc &DL) { SDValue RHSShiftArg = RHSShift.getOperand(0); SDValue RHSShiftAmt = RHSShift.getOperand(1); - // fold (or (shl x, C1), (srl x, C2)) -> (rotl x, C1) - // fold (or (shl x, C1), (srl x, C2)) -> (rotr x, C2) - // fold (or (shl x, C1), (srl y, C2)) -> (fshl x, y, C1) - // fold (or (shl x, C1), (srl y, C2)) -> (fshr x, y, C2) - // iff C1+C2 == EltSizeInBits auto MatchRotateSum = [EltSizeInBits](ConstantSDNode *LHS, ConstantSDNode *RHS) { return (LHS->getAPIntValue() + RHS->getAPIntValue()) == EltSizeInBits; }; - if (ISD::matchBinaryPredicate(LHSShiftAmt, RHSShiftAmt, MatchRotateSum)) { - SDValue Res; - if (IsRotate && (HasROTL || HasROTR || !(HasFSHL || HasFSHR))) { - bool UseROTL = !LegalOperations || HasROTL; - Res = DAG.getNode(UseROTL ? ISD::ROTL : ISD::ROTR, DL, VT, LHSShiftArg, - UseROTL ? LHSShiftAmt : RHSShiftAmt); - } else { - bool UseFSHL = !LegalOperations || HasFSHL; - Res = DAG.getNode(UseFSHL ? ISD::FSHL : ISD::FSHR, DL, VT, LHSShiftArg, - RHSShiftArg, UseFSHL ? LHSShiftAmt : RHSShiftAmt); - } + auto ApplyMasks = [&](SDValue Res) { // If there is an AND of either shifted operand, apply it to the result. if (LHSMask.getNode() || RHSMask.getNode()) { SDValue AllOnes = DAG.getAllOnesConstant(DL, VT); @@ -7313,6 +7534,71 @@ SDValue DAGCombiner::MatchRotate(SDValue LHS, SDValue RHS, const SDLoc &DL) { } return Res; + }; + + // TODO: Support pre-legalization funnel-shift by constant. + bool IsRotate = LHSShift.getOperand(0) == RHSShift.getOperand(0); + if (!IsRotate && !(HasFSHL || HasFSHR)) { + if (TLI.isTypeLegal(VT) && LHS.hasOneUse() && RHS.hasOneUse() && + ISD::matchBinaryPredicate(LHSShiftAmt, RHSShiftAmt, MatchRotateSum)) { + // Look for a disguised rotate by constant. + // The common shifted operand X may be hidden inside another 'or'. + SDValue X, Y; + auto matchOr = [&X, &Y](SDValue Or, SDValue CommonOp) { + if (!Or.hasOneUse() || Or.getOpcode() != ISD::OR) + return false; + if (CommonOp == Or.getOperand(0)) { + X = CommonOp; + Y = Or.getOperand(1); + return true; + } + if (CommonOp == Or.getOperand(1)) { + X = CommonOp; + Y = Or.getOperand(0); + return true; + } + return false; + }; + + SDValue Res; + if (matchOr(LHSShiftArg, RHSShiftArg)) { + // (shl (X | Y), C1) | (srl X, C2) --> (rotl X, C1) | (shl Y, C1) + SDValue RotX = DAG.getNode(ISD::ROTL, DL, VT, X, LHSShiftAmt); + SDValue ShlY = DAG.getNode(ISD::SHL, DL, VT, Y, LHSShiftAmt); + Res = DAG.getNode(ISD::OR, DL, VT, RotX, ShlY); + } else if (matchOr(RHSShiftArg, LHSShiftArg)) { + // (shl X, C1) | (srl (X | Y), C2) --> (rotl X, C1) | (srl Y, C2) + SDValue RotX = DAG.getNode(ISD::ROTL, DL, VT, X, LHSShiftAmt); + SDValue SrlY = DAG.getNode(ISD::SRL, DL, VT, Y, RHSShiftAmt); + Res = DAG.getNode(ISD::OR, DL, VT, RotX, SrlY); + } else { + return SDValue(); + } + + return ApplyMasks(Res); + } + + return SDValue(); // Requires funnel shift support. + } + + // fold (or (shl x, C1), (srl x, C2)) -> (rotl x, C1) + // fold (or (shl x, C1), (srl x, C2)) -> (rotr x, C2) + // fold (or (shl x, C1), (srl y, C2)) -> (fshl x, y, C1) + // fold (or (shl x, C1), (srl y, C2)) -> (fshr x, y, C2) + // iff C1+C2 == EltSizeInBits + if (ISD::matchBinaryPredicate(LHSShiftAmt, RHSShiftAmt, MatchRotateSum)) { + SDValue Res; + if (IsRotate && (HasROTL || HasROTR || !(HasFSHL || HasFSHR))) { + bool UseROTL = !LegalOperations || HasROTL; + Res = DAG.getNode(UseROTL ? ISD::ROTL : ISD::ROTR, DL, VT, LHSShiftArg, + UseROTL ? LHSShiftAmt : RHSShiftAmt); + } else { + bool UseFSHL = !LegalOperations || HasFSHL; + Res = DAG.getNode(UseFSHL ? ISD::FSHL : ISD::FSHR, DL, VT, LHSShiftArg, + RHSShiftArg, UseFSHL ? LHSShiftAmt : RHSShiftAmt); + } + + return ApplyMasks(Res); } // Even pre-legalization, we can't easily rotate/funnel-shift by a variable @@ -7343,26 +7629,26 @@ SDValue DAGCombiner::MatchRotate(SDValue LHS, SDValue RHS, const SDLoc &DL) { if (IsRotate && (HasROTL || HasROTR)) { SDValue TryL = MatchRotatePosNeg(LHSShiftArg, LHSShiftAmt, RHSShiftAmt, LExtOp0, - RExtOp0, ISD::ROTL, ISD::ROTR, DL); + RExtOp0, HasROTL, ISD::ROTL, ISD::ROTR, DL); if (TryL) return TryL; SDValue TryR = MatchRotatePosNeg(RHSShiftArg, RHSShiftAmt, LHSShiftAmt, RExtOp0, - LExtOp0, ISD::ROTR, ISD::ROTL, DL); + LExtOp0, HasROTR, ISD::ROTR, ISD::ROTL, DL); if (TryR) return TryR; } SDValue TryL = MatchFunnelPosNeg(LHSShiftArg, RHSShiftArg, LHSShiftAmt, RHSShiftAmt, - LExtOp0, RExtOp0, ISD::FSHL, ISD::FSHR, DL); + LExtOp0, RExtOp0, HasFSHL, ISD::FSHL, ISD::FSHR, DL); if (TryL) return TryL; SDValue TryR = MatchFunnelPosNeg(LHSShiftArg, RHSShiftArg, RHSShiftAmt, LHSShiftAmt, - RExtOp0, LExtOp0, ISD::FSHR, ISD::FSHL, DL); + RExtOp0, LExtOp0, HasFSHR, ISD::FSHR, ISD::FSHL, DL); if (TryR) return TryR; @@ -7877,7 +8163,7 @@ SDValue DAGCombiner::MatchLoadCombine(SDNode *N) { // little endian value load Optional<bool> IsBigEndian = isBigEndian( makeArrayRef(ByteOffsets).drop_back(ZeroExtendedBytes), FirstOffset); - if (!IsBigEndian.hasValue()) + if (!IsBigEndian) return SDValue(); assert(FirstByteProvider && "must be set"); @@ -8084,6 +8370,13 @@ SDValue DAGCombiner::visitXOR(SDNode *N) { if (SDValue RXOR = reassociateOps(ISD::XOR, DL, N0, N1, N->getFlags())) return RXOR; + // look for 'add-like' folds: + // XOR(N0,MIN_SIGNED_VALUE) == ADD(N0,MIN_SIGNED_VALUE) + if ((!LegalOperations || TLI.isOperationLegal(ISD::ADD, VT)) && + isMinSignedConstant(N1)) + if (SDValue Combined = visitADDLike(N)) + return Combined; + // fold !(x cc y) -> (x !cc y) unsigned N0Opcode = N0.getOpcode(); SDValue LHS, RHS, CC; @@ -8249,6 +8542,11 @@ SDValue DAGCombiner::visitXOR(SDNode *N) { if (SDValue V = hoistLogicOpWithSameOpcodeHands(N)) return V; + if (SDValue R = foldLogicOfShifts(N, N0, N1, DAG)) + return R; + if (SDValue R = foldLogicOfShifts(N, N1, N0, DAG)) + return R; + // Unfold ((x ^ y) & m) ^ y into (x & m) | (y & ~m) if profitable if (SDValue MM = unfoldMaskedMerge(N)) return MM; @@ -8479,7 +8777,9 @@ SDValue DAGCombiner::visitRotate(SDNode *N) { } unsigned NextOp = N0.getOpcode(); - // fold (rot* (rot* x, c2), c1) -> (rot* x, c1 +- c2 % bitsize) + + // fold (rot* (rot* x, c2), c1) + // -> (rot* x, ((c1 % bitsize) +- (c2 % bitsize)) % bitsize) if (NextOp == ISD::ROTL || NextOp == ISD::ROTR) { SDNode *C1 = DAG.isConstantIntBuildVectorOrConstantInt(N1); SDNode *C2 = DAG.isConstantIntBuildVectorOrConstantInt(N0.getOperand(1)); @@ -8487,14 +8787,19 @@ SDValue DAGCombiner::visitRotate(SDNode *N) { EVT ShiftVT = C1->getValueType(0); bool SameSide = (N->getOpcode() == NextOp); unsigned CombineOp = SameSide ? ISD::ADD : ISD::SUB; - if (SDValue CombinedShift = DAG.FoldConstantArithmetic( - CombineOp, dl, ShiftVT, {N1, N0.getOperand(1)})) { - SDValue BitsizeC = DAG.getConstant(Bitsize, dl, ShiftVT); - SDValue CombinedShiftNorm = DAG.FoldConstantArithmetic( - ISD::SREM, dl, ShiftVT, {CombinedShift, BitsizeC}); - return DAG.getNode(N->getOpcode(), dl, VT, N0->getOperand(0), - CombinedShiftNorm); - } + SDValue BitsizeC = DAG.getConstant(Bitsize, dl, ShiftVT); + SDValue Norm1 = DAG.FoldConstantArithmetic(ISD::UREM, dl, ShiftVT, + {N1, BitsizeC}); + SDValue Norm2 = DAG.FoldConstantArithmetic(ISD::UREM, dl, ShiftVT, + {N0.getOperand(1), BitsizeC}); + if (Norm1 && Norm2) + if (SDValue CombinedShift = DAG.FoldConstantArithmetic( + CombineOp, dl, ShiftVT, {Norm1, Norm2})) { + SDValue CombinedShiftNorm = DAG.FoldConstantArithmetic( + ISD::UREM, dl, ShiftVT, {CombinedShift, BitsizeC}); + return DAG.getNode(N->getOpcode(), dl, VT, N0->getOperand(0), + CombinedShiftNorm); + } } } return SDValue(); @@ -8654,52 +8959,63 @@ SDValue DAGCombiner::visitSHL(SDNode *N) { } } - // fold (shl (sr[la] exact X, C1), C2) -> (shl X, (C2-C1)) if C1 <= C2 - // fold (shl (sr[la] exact X, C1), C2) -> (sr[la] X, (C2-C1)) if C1 > C2 - // TODO - support non-uniform vector shift amounts. - ConstantSDNode *N1C = isConstOrConstSplat(N1); - if (N1C && (N0.getOpcode() == ISD::SRL || N0.getOpcode() == ISD::SRA) && - N0->getFlags().hasExact()) { - if (ConstantSDNode *N0C1 = isConstOrConstSplat(N0.getOperand(1))) { - uint64_t C1 = N0C1->getZExtValue(); - uint64_t C2 = N1C->getZExtValue(); - SDLoc DL(N); - if (C1 <= C2) - return DAG.getNode(ISD::SHL, DL, VT, N0.getOperand(0), - DAG.getConstant(C2 - C1, DL, ShiftVT)); - return DAG.getNode(N0.getOpcode(), DL, VT, N0.getOperand(0), - DAG.getConstant(C1 - C2, DL, ShiftVT)); + if (N0.getOpcode() == ISD::SRL || N0.getOpcode() == ISD::SRA) { + auto MatchShiftAmount = [OpSizeInBits](ConstantSDNode *LHS, + ConstantSDNode *RHS) { + const APInt &LHSC = LHS->getAPIntValue(); + const APInt &RHSC = RHS->getAPIntValue(); + return LHSC.ult(OpSizeInBits) && RHSC.ult(OpSizeInBits) && + LHSC.getZExtValue() <= RHSC.getZExtValue(); + }; + + SDLoc DL(N); + + // fold (shl (sr[la] exact X, C1), C2) -> (shl X, (C2-C1)) if C1 <= C2 + // fold (shl (sr[la] exact X, C1), C2) -> (sr[la] X, (C2-C1)) if C1 >= C2 + if (N0->getFlags().hasExact()) { + if (ISD::matchBinaryPredicate(N0.getOperand(1), N1, MatchShiftAmount, + /*AllowUndefs*/ false, + /*AllowTypeMismatch*/ true)) { + SDValue N01 = DAG.getZExtOrTrunc(N0.getOperand(1), DL, ShiftVT); + SDValue Diff = DAG.getNode(ISD::SUB, DL, ShiftVT, N1, N01); + return DAG.getNode(ISD::SHL, DL, VT, N0.getOperand(0), Diff); + } + if (ISD::matchBinaryPredicate(N1, N0.getOperand(1), MatchShiftAmount, + /*AllowUndefs*/ false, + /*AllowTypeMismatch*/ true)) { + SDValue N01 = DAG.getZExtOrTrunc(N0.getOperand(1), DL, ShiftVT); + SDValue Diff = DAG.getNode(ISD::SUB, DL, ShiftVT, N01, N1); + return DAG.getNode(N0.getOpcode(), DL, VT, N0.getOperand(0), Diff); + } } - } - // fold (shl (srl x, c1), c2) -> (and (shl x, (sub c2, c1), MASK) or - // (and (srl x, (sub c1, c2), MASK) - // Only fold this if the inner shift has no other uses -- if it does, folding - // this will increase the total number of instructions. - // TODO - drop hasOneUse requirement if c1 == c2? - // TODO - support non-uniform vector shift amounts. - if (N1C && N0.getOpcode() == ISD::SRL && N0.hasOneUse() && - TLI.shouldFoldConstantShiftPairToMask(N, Level)) { - if (ConstantSDNode *N0C1 = isConstOrConstSplat(N0.getOperand(1))) { - if (N0C1->getAPIntValue().ult(OpSizeInBits)) { - uint64_t c1 = N0C1->getZExtValue(); - uint64_t c2 = N1C->getZExtValue(); - APInt Mask = APInt::getHighBitsSet(OpSizeInBits, OpSizeInBits - c1); - SDValue Shift; - if (c2 > c1) { - Mask <<= c2 - c1; - SDLoc DL(N); - Shift = DAG.getNode(ISD::SHL, DL, VT, N0.getOperand(0), - DAG.getConstant(c2 - c1, DL, ShiftVT)); - } else { - Mask.lshrInPlace(c1 - c2); - SDLoc DL(N); - Shift = DAG.getNode(ISD::SRL, DL, VT, N0.getOperand(0), - DAG.getConstant(c1 - c2, DL, ShiftVT)); - } - SDLoc DL(N0); - return DAG.getNode(ISD::AND, DL, VT, Shift, - DAG.getConstant(Mask, DL, VT)); + // fold (shl (srl x, c1), c2) -> (and (shl x, (sub c2, c1), MASK) or + // (and (srl x, (sub c1, c2), MASK) + // Only fold this if the inner shift has no other uses -- if it does, + // folding this will increase the total number of instructions. + if (N0.getOpcode() == ISD::SRL && + (N0.getOperand(1) == N1 || N0.hasOneUse()) && + TLI.shouldFoldConstantShiftPairToMask(N, Level)) { + if (ISD::matchBinaryPredicate(N1, N0.getOperand(1), MatchShiftAmount, + /*AllowUndefs*/ false, + /*AllowTypeMismatch*/ true)) { + SDValue N01 = DAG.getZExtOrTrunc(N0.getOperand(1), DL, ShiftVT); + SDValue Diff = DAG.getNode(ISD::SUB, DL, ShiftVT, N01, N1); + SDValue Mask = DAG.getAllOnesConstant(DL, VT); + Mask = DAG.getNode(ISD::SHL, DL, VT, Mask, N01); + Mask = DAG.getNode(ISD::SRL, DL, VT, Mask, Diff); + SDValue Shift = DAG.getNode(ISD::SRL, DL, VT, N0.getOperand(0), Diff); + return DAG.getNode(ISD::AND, DL, VT, Shift, Mask); + } + if (ISD::matchBinaryPredicate(N0.getOperand(1), N1, MatchShiftAmount, + /*AllowUndefs*/ false, + /*AllowTypeMismatch*/ true)) { + SDValue N01 = DAG.getZExtOrTrunc(N0.getOperand(1), DL, ShiftVT); + SDValue Diff = DAG.getNode(ISD::SUB, DL, ShiftVT, N1, N01); + SDValue Mask = DAG.getAllOnesConstant(DL, VT); + Mask = DAG.getNode(ISD::SHL, DL, VT, Mask, N1); + SDValue Shift = DAG.getNode(ISD::SHL, DL, VT, N0.getOperand(0), Diff); + return DAG.getNode(ISD::AND, DL, VT, Shift, Mask); } } } @@ -8718,7 +9034,7 @@ SDValue DAGCombiner::visitSHL(SDNode *N) { // Variant of version done on multiply, except mul by a power of 2 is turned // into a shift. if ((N0.getOpcode() == ISD::ADD || N0.getOpcode() == ISD::OR) && - N0.getNode()->hasOneUse() && + N0->hasOneUse() && isConstantOrConstantVector(N1, /* No Opaques */ true) && isConstantOrConstantVector(N0.getOperand(1), /* No Opaques */ true) && TLI.isDesirableToCommuteWithShift(N, Level)) { @@ -8730,14 +9046,14 @@ SDValue DAGCombiner::visitSHL(SDNode *N) { } // fold (shl (mul x, c1), c2) -> (mul x, c1 << c2) - if (N0.getOpcode() == ISD::MUL && N0.getNode()->hasOneUse() && - isConstantOrConstantVector(N1, /* No Opaques */ true) && - isConstantOrConstantVector(N0.getOperand(1), /* No Opaques */ true)) { - SDValue Shl = DAG.getNode(ISD::SHL, SDLoc(N1), VT, N0.getOperand(1), N1); - if (isConstantOrConstantVector(Shl)) + if (N0.getOpcode() == ISD::MUL && N0->hasOneUse()) { + SDValue N01 = N0.getOperand(1); + if (SDValue Shl = + DAG.FoldConstantArithmetic(ISD::SHL, SDLoc(N1), VT, {N01, N1})) return DAG.getNode(ISD::MUL, SDLoc(N), VT, N0.getOperand(0), Shl); } + ConstantSDNode *N1C = isConstOrConstSplat(N1); if (N1C && !N1C->isOpaque()) if (SDValue NewSHL = visitShiftByConstant(N)) return NewSHL; @@ -9023,8 +9339,10 @@ SDValue DAGCombiner::visitSRA(SDNode *N) { unsigned TruncBits = LargeVT.getScalarSizeInBits() - OpSizeInBits; if (LargeShift->getAPIntValue() == TruncBits) { SDLoc DL(N); - SDValue Amt = DAG.getConstant(N1C->getZExtValue() + TruncBits, DL, - getShiftAmountTy(LargeVT)); + EVT LargeShiftVT = getShiftAmountTy(LargeVT); + SDValue Amt = DAG.getZExtOrTrunc(N1, DL, LargeShiftVT); + Amt = DAG.getNode(ISD::ADD, DL, LargeShiftVT, Amt, + DAG.getConstant(TruncBits, DL, LargeShiftVT)); SDValue SRA = DAG.getNode(ISD::SRA, DL, LargeVT, N0Op0.getOperand(0), Amt); return DAG.getNode(ISD::TRUNCATE, DL, VT, SRA); @@ -9063,6 +9381,7 @@ SDValue DAGCombiner::visitSRL(SDNode *N) { return V; EVT VT = N0.getValueType(); + EVT ShiftVT = N1.getValueType(); unsigned OpSizeInBits = VT.getScalarSizeInBits(); // fold (srl c1, c2) -> c1 >>u c2 @@ -9104,7 +9423,6 @@ SDValue DAGCombiner::visitSRL(SDNode *N) { }; if (ISD::matchBinaryPredicate(N1, N0.getOperand(1), MatchInRange)) { SDLoc DL(N); - EVT ShiftVT = N1.getValueType(); SDValue Sum = DAG.getNode(ISD::ADD, DL, ShiftVT, N1, N0.getOperand(1)); return DAG.getNode(ISD::SRL, DL, VT, N0.getOperand(0), Sum); } @@ -9148,15 +9466,41 @@ SDValue DAGCombiner::visitSRL(SDNode *N) { } } - // fold (srl (shl x, c), c) -> (and x, cst2) - // TODO - (srl (shl x, c1), c2). - if (N0.getOpcode() == ISD::SHL && N0.getOperand(1) == N1 && - isConstantOrConstantVector(N1, /* NoOpaques */ true)) { - SDLoc DL(N); - SDValue Mask = - DAG.getNode(ISD::SRL, DL, VT, DAG.getAllOnesConstant(DL, VT), N1); - AddToWorklist(Mask.getNode()); - return DAG.getNode(ISD::AND, DL, VT, N0.getOperand(0), Mask); + // fold (srl (shl x, c1), c2) -> (and (shl x, (sub c1, c2), MASK) or + // (and (srl x, (sub c2, c1), MASK) + if (N0.getOpcode() == ISD::SHL && + (N0.getOperand(1) == N1 || N0->hasOneUse()) && + TLI.shouldFoldConstantShiftPairToMask(N, Level)) { + auto MatchShiftAmount = [OpSizeInBits](ConstantSDNode *LHS, + ConstantSDNode *RHS) { + const APInt &LHSC = LHS->getAPIntValue(); + const APInt &RHSC = RHS->getAPIntValue(); + return LHSC.ult(OpSizeInBits) && RHSC.ult(OpSizeInBits) && + LHSC.getZExtValue() <= RHSC.getZExtValue(); + }; + if (ISD::matchBinaryPredicate(N1, N0.getOperand(1), MatchShiftAmount, + /*AllowUndefs*/ false, + /*AllowTypeMismatch*/ true)) { + SDLoc DL(N); + SDValue N01 = DAG.getZExtOrTrunc(N0.getOperand(1), DL, ShiftVT); + SDValue Diff = DAG.getNode(ISD::SUB, DL, ShiftVT, N01, N1); + SDValue Mask = DAG.getAllOnesConstant(DL, VT); + Mask = DAG.getNode(ISD::SRL, DL, VT, Mask, N01); + Mask = DAG.getNode(ISD::SHL, DL, VT, Mask, Diff); + SDValue Shift = DAG.getNode(ISD::SHL, DL, VT, N0.getOperand(0), Diff); + return DAG.getNode(ISD::AND, DL, VT, Shift, Mask); + } + if (ISD::matchBinaryPredicate(N0.getOperand(1), N1, MatchShiftAmount, + /*AllowUndefs*/ false, + /*AllowTypeMismatch*/ true)) { + SDLoc DL(N); + SDValue N01 = DAG.getZExtOrTrunc(N0.getOperand(1), DL, ShiftVT); + SDValue Diff = DAG.getNode(ISD::SUB, DL, ShiftVT, N1, N01); + SDValue Mask = DAG.getAllOnesConstant(DL, VT); + Mask = DAG.getNode(ISD::SRL, DL, VT, Mask, N1); + SDValue Shift = DAG.getNode(ISD::SRL, DL, VT, N0.getOperand(0), Diff); + return DAG.getNode(ISD::AND, DL, VT, Shift, Mask); + } } // fold (srl (anyextend x), c) -> (and (anyextend (srl x, c)), mask) @@ -9412,6 +9756,21 @@ SDValue DAGCombiner::visitSHLSAT(SDNode *N) { DAG.FoldConstantArithmetic(N->getOpcode(), SDLoc(N), VT, {N0, N1})) return C; + ConstantSDNode *N1C = isConstOrConstSplat(N1); + + if (!LegalOperations || TLI.isOperationLegalOrCustom(ISD::SHL, VT)) { + // fold (sshlsat x, c) -> (shl x, c) + if (N->getOpcode() == ISD::SSHLSAT && N1C && + N1C->getAPIntValue().ult(DAG.ComputeNumSignBits(N0))) + return DAG.getNode(ISD::SHL, SDLoc(N), VT, N0, N1); + + // fold (ushlsat x, c) -> (shl x, c) + if (N->getOpcode() == ISD::USHLSAT && N1C && + N1C->getAPIntValue().ule( + DAG.computeKnownBits(N0).countMinLeadingZeros())) + return DAG.getNode(ISD::SHL, SDLoc(N), VT, N0, N1); + } + return SDValue(); } @@ -9435,18 +9794,27 @@ static SDValue combineABSToABD(SDNode *N, SelectionDAG &DAG, (Opc0 != ISD::ZERO_EXTEND && Opc0 != ISD::SIGN_EXTEND)) return SDValue(); + EVT VT = N->getValueType(0); EVT VT1 = Op0.getOperand(0).getValueType(); EVT VT2 = Op1.getOperand(0).getValueType(); - // Check if the operands are of same type and valid size. unsigned ABDOpcode = (Opc0 == ISD::SIGN_EXTEND) ? ISD::ABDS : ISD::ABDU; - if (VT1 != VT2 || !TLI.isOperationLegalOrCustom(ABDOpcode, VT1)) - return SDValue(); - Op0 = Op0.getOperand(0); - Op1 = Op1.getOperand(0); - SDValue ABD = - DAG.getNode(ABDOpcode, SDLoc(N), Op0->getValueType(0), Op0, Op1); - return DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N), N->getValueType(0), ABD); + // fold abs(sext(x) - sext(y)) -> zext(abds(x, y)) + // fold abs(zext(x) - zext(y)) -> zext(abdu(x, y)) + // NOTE: Extensions must be equivalent. + if (VT1 == VT2 && TLI.isOperationLegalOrCustom(ABDOpcode, VT1)) { + Op0 = Op0.getOperand(0); + Op1 = Op1.getOperand(0); + SDValue ABD = DAG.getNode(ABDOpcode, SDLoc(N), VT1, Op0, Op1); + return DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N), VT, ABD); + } + + // fold abs(sext(x) - sext(y)) -> abds(sext(x), sext(y)) + // fold abs(zext(x) - zext(y)) -> abdu(zext(x), zext(y)) + if (TLI.isOperationLegalOrCustom(ABDOpcode, VT)) + return DAG.getNode(ABDOpcode, SDLoc(N), VT, Op0, Op1); + + return SDValue(); } SDValue DAGCombiner::visitABS(SDNode *N) { @@ -9472,24 +9840,60 @@ SDValue DAGCombiner::visitABS(SDNode *N) { SDValue DAGCombiner::visitBSWAP(SDNode *N) { SDValue N0 = N->getOperand(0); EVT VT = N->getValueType(0); + SDLoc DL(N); // fold (bswap c1) -> c2 if (DAG.isConstantIntBuildVectorOrConstantInt(N0)) - return DAG.getNode(ISD::BSWAP, SDLoc(N), VT, N0); + return DAG.getNode(ISD::BSWAP, DL, VT, N0); // fold (bswap (bswap x)) -> x if (N0.getOpcode() == ISD::BSWAP) - return N0->getOperand(0); + return N0.getOperand(0); // Canonicalize bswap(bitreverse(x)) -> bitreverse(bswap(x)). If bitreverse // isn't supported, it will be expanded to bswap followed by a manual reversal // of bits in each byte. By placing bswaps before bitreverse, we can remove // the two bswaps if the bitreverse gets expanded. if (N0.getOpcode() == ISD::BITREVERSE && N0.hasOneUse()) { - SDLoc DL(N); SDValue BSwap = DAG.getNode(ISD::BSWAP, DL, VT, N0.getOperand(0)); return DAG.getNode(ISD::BITREVERSE, DL, VT, BSwap); } + // fold (bswap shl(x,c)) -> (zext(bswap(trunc(shl(x,sub(c,bw/2)))))) + // iff x >= bw/2 (i.e. lower half is known zero) + unsigned BW = VT.getScalarSizeInBits(); + if (BW >= 32 && N0.getOpcode() == ISD::SHL && N0.hasOneUse()) { + auto *ShAmt = dyn_cast<ConstantSDNode>(N0.getOperand(1)); + EVT HalfVT = EVT::getIntegerVT(*DAG.getContext(), BW / 2); + if (ShAmt && ShAmt->getAPIntValue().ult(BW) && + ShAmt->getZExtValue() >= (BW / 2) && + (ShAmt->getZExtValue() % 16) == 0 && TLI.isTypeLegal(HalfVT) && + TLI.isTruncateFree(VT, HalfVT) && + (!LegalOperations || hasOperation(ISD::BSWAP, HalfVT))) { + SDValue Res = N0.getOperand(0); + if (uint64_t NewShAmt = (ShAmt->getZExtValue() - (BW / 2))) + Res = DAG.getNode(ISD::SHL, DL, VT, Res, + DAG.getConstant(NewShAmt, DL, getShiftAmountTy(VT))); + Res = DAG.getZExtOrTrunc(Res, DL, HalfVT); + Res = DAG.getNode(ISD::BSWAP, DL, HalfVT, Res); + return DAG.getZExtOrTrunc(Res, DL, VT); + } + } + + // Try to canonicalize bswap-of-logical-shift-by-8-bit-multiple as + // inverse-shift-of-bswap: + // bswap (X u<< C) --> (bswap X) u>> C + // bswap (X u>> C) --> (bswap X) u<< C + if ((N0.getOpcode() == ISD::SHL || N0.getOpcode() == ISD::SRL) && + N0.hasOneUse()) { + auto *ShAmt = dyn_cast<ConstantSDNode>(N0.getOperand(1)); + if (ShAmt && ShAmt->getAPIntValue().ult(BW) && + ShAmt->getZExtValue() % 8 == 0) { + SDValue NewSwap = DAG.getNode(ISD::BSWAP, DL, VT, N0.getOperand(0)); + unsigned InverseShift = N0.getOpcode() == ISD::SHL ? ISD::SRL : ISD::SHL; + return DAG.getNode(InverseShift, DL, VT, NewSwap, N0.getOperand(1)); + } + } + return SDValue(); } @@ -9740,7 +10144,8 @@ SDValue DAGCombiner::foldSelectOfConstants(SDNode *N) { if (C1Val.isPowerOf2() && C2Val.isZero()) { if (VT != MVT::i1) Cond = DAG.getNode(ISD::ZERO_EXTEND, DL, VT, Cond); - SDValue ShAmtC = DAG.getConstant(C1Val.exactLogBase2(), DL, VT); + SDValue ShAmtC = + DAG.getShiftAmountConstant(C1Val.exactLogBase2(), VT, DL); return DAG.getNode(ISD::SHL, DL, VT, Cond, ShAmtC); } @@ -10023,7 +10428,7 @@ SDValue DAGCombiner::visitSELECT(SDNode *N) { TLI.isOperationLegalOrCustom(ISD::SELECT_CC, VT))) { // Any flags available in a select/setcc fold will be on the setcc as they // migrated from fcmp - Flags = N0.getNode()->getFlags(); + Flags = N0->getFlags(); SDValue SelectNode = DAG.getNode(ISD::SELECT_CC, DL, VT, Cond0, Cond1, N1, N2, N0.getOperand(2)); SelectNode->setFlags(Flags); @@ -10096,14 +10501,19 @@ static SDValue ConvertSelectToConcatVector(SDNode *N, SelectionDAG &DAG) { TopHalf->isZero() ? RHS->getOperand(1) : LHS->getOperand(1)); } -bool refineUniformBase(SDValue &BasePtr, SDValue &Index, SelectionDAG &DAG) { +bool refineUniformBase(SDValue &BasePtr, SDValue &Index, bool IndexIsScaled, + SelectionDAG &DAG) { if (!isNullConstant(BasePtr) || Index.getOpcode() != ISD::ADD) return false; + // Only perform the transformation when existing operands can be reused. + if (IndexIsScaled) + return false; + // For now we check only the LHS of the add. SDValue LHS = Index.getOperand(0); SDValue SplatVal = DAG.getSplatValue(LHS); - if (!SplatVal) + if (!SplatVal || SplatVal.getValueType() != BasePtr.getValueType()) return false; BasePtr = SplatVal; @@ -10112,23 +10522,29 @@ bool refineUniformBase(SDValue &BasePtr, SDValue &Index, SelectionDAG &DAG) { } // Fold sext/zext of index into index type. -bool refineIndexType(MaskedGatherScatterSDNode *MGS, SDValue &Index, - bool Scaled, SelectionDAG &DAG) { +bool refineIndexType(SDValue &Index, ISD::MemIndexType &IndexType, EVT DataVT, + SelectionDAG &DAG) { const TargetLowering &TLI = DAG.getTargetLoweringInfo(); + // It's always safe to look through zero extends. if (Index.getOpcode() == ISD::ZERO_EXTEND) { SDValue Op = Index.getOperand(0); - MGS->setIndexType(Scaled ? ISD::UNSIGNED_SCALED : ISD::UNSIGNED_UNSCALED); - if (TLI.shouldRemoveExtendFromGSIndex(Op.getValueType())) { + if (TLI.shouldRemoveExtendFromGSIndex(Op.getValueType(), DataVT)) { + IndexType = ISD::UNSIGNED_SCALED; Index = Op; return true; } + if (ISD::isIndexTypeSigned(IndexType)) { + IndexType = ISD::UNSIGNED_SCALED; + return true; + } } - if (Index.getOpcode() == ISD::SIGN_EXTEND) { + // It's only safe to look through sign extends when Index is signed. + if (Index.getOpcode() == ISD::SIGN_EXTEND && + ISD::isIndexTypeSigned(IndexType)) { SDValue Op = Index.getOperand(0); - MGS->setIndexType(Scaled ? ISD::SIGNED_SCALED : ISD::SIGNED_UNSCALED); - if (TLI.shouldRemoveExtendFromGSIndex(Op.getValueType())) { + if (TLI.shouldRemoveExtendFromGSIndex(Op.getValueType(), DataVT)) { Index = Op; return true; } @@ -10145,24 +10561,25 @@ SDValue DAGCombiner::visitMSCATTER(SDNode *N) { SDValue Scale = MSC->getScale(); SDValue StoreVal = MSC->getValue(); SDValue BasePtr = MSC->getBasePtr(); + ISD::MemIndexType IndexType = MSC->getIndexType(); SDLoc DL(N); // Zap scatters with a zero mask. if (ISD::isConstantSplatVectorAllZeros(Mask.getNode())) return Chain; - if (refineUniformBase(BasePtr, Index, DAG)) { + if (refineUniformBase(BasePtr, Index, MSC->isIndexScaled(), DAG)) { SDValue Ops[] = {Chain, StoreVal, Mask, BasePtr, Index, Scale}; - return DAG.getMaskedScatter( - DAG.getVTList(MVT::Other), MSC->getMemoryVT(), DL, Ops, - MSC->getMemOperand(), MSC->getIndexType(), MSC->isTruncatingStore()); + return DAG.getMaskedScatter(DAG.getVTList(MVT::Other), MSC->getMemoryVT(), + DL, Ops, MSC->getMemOperand(), IndexType, + MSC->isTruncatingStore()); } - if (refineIndexType(MSC, Index, MSC->isIndexScaled(), DAG)) { + if (refineIndexType(Index, IndexType, StoreVal.getValueType(), DAG)) { SDValue Ops[] = {Chain, StoreVal, Mask, BasePtr, Index, Scale}; - return DAG.getMaskedScatter( - DAG.getVTList(MVT::Other), MSC->getMemoryVT(), DL, Ops, - MSC->getMemOperand(), MSC->getIndexType(), MSC->isTruncatingStore()); + return DAG.getMaskedScatter(DAG.getVTList(MVT::Other), MSC->getMemoryVT(), + DL, Ops, MSC->getMemOperand(), IndexType, + MSC->isTruncatingStore()); } return SDValue(); @@ -10217,7 +10634,7 @@ SDValue DAGCombiner::visitMSTORE(SDNode *N) { // If this is a TRUNC followed by a masked store, fold this into a masked // truncating store. We can do this even if this is already a masked // truncstore. - if ((Value.getOpcode() == ISD::TRUNCATE) && Value.getNode()->hasOneUse() && + if ((Value.getOpcode() == ISD::TRUNCATE) && Value->hasOneUse() && MST->isUnindexed() && TLI.canCombineTruncStore(Value.getOperand(0).getValueType(), MST->getMemoryVT(), LegalOperations)) { @@ -10240,26 +10657,25 @@ SDValue DAGCombiner::visitMGATHER(SDNode *N) { SDValue Scale = MGT->getScale(); SDValue PassThru = MGT->getPassThru(); SDValue BasePtr = MGT->getBasePtr(); + ISD::MemIndexType IndexType = MGT->getIndexType(); SDLoc DL(N); // Zap gathers with a zero mask. if (ISD::isConstantSplatVectorAllZeros(Mask.getNode())) return CombineTo(N, PassThru, MGT->getChain()); - if (refineUniformBase(BasePtr, Index, DAG)) { + if (refineUniformBase(BasePtr, Index, MGT->isIndexScaled(), DAG)) { SDValue Ops[] = {Chain, PassThru, Mask, BasePtr, Index, Scale}; - return DAG.getMaskedGather(DAG.getVTList(N->getValueType(0), MVT::Other), - MGT->getMemoryVT(), DL, Ops, - MGT->getMemOperand(), MGT->getIndexType(), - MGT->getExtensionType()); + return DAG.getMaskedGather( + DAG.getVTList(N->getValueType(0), MVT::Other), MGT->getMemoryVT(), DL, + Ops, MGT->getMemOperand(), IndexType, MGT->getExtensionType()); } - if (refineIndexType(MGT, Index, MGT->isIndexScaled(), DAG)) { + if (refineIndexType(Index, IndexType, N->getValueType(0), DAG)) { SDValue Ops[] = {Chain, PassThru, Mask, BasePtr, Index, Scale}; - return DAG.getMaskedGather(DAG.getVTList(N->getValueType(0), MVT::Other), - MGT->getMemoryVT(), DL, Ops, - MGT->getMemOperand(), MGT->getIndexType(), - MGT->getExtensionType()); + return DAG.getMaskedGather( + DAG.getVTList(N->getValueType(0), MVT::Other), MGT->getMemoryVT(), DL, + Ops, MGT->getMemOperand(), IndexType, MGT->getExtensionType()); } return SDValue(); @@ -10513,23 +10929,25 @@ SDValue DAGCombiner::visitVSELECT(SDNode *N) { Other = N1; } + // zext(x) >= y ? trunc(zext(x) - y) : 0 + // --> usubsat(trunc(zext(x)),trunc(umin(y,SatLimit))) + // zext(x) > y ? trunc(zext(x) - y) : 0 + // --> usubsat(trunc(zext(x)),trunc(umin(y,SatLimit))) + if (Other && Other.getOpcode() == ISD::TRUNCATE && + Other.getOperand(0).getOpcode() == ISD::SUB && + (SatCC == ISD::SETUGE || SatCC == ISD::SETUGT)) { + SDValue OpLHS = Other.getOperand(0).getOperand(0); + SDValue OpRHS = Other.getOperand(0).getOperand(1); + if (LHS == OpLHS && RHS == OpRHS && LHS.getOpcode() == ISD::ZERO_EXTEND) + if (SDValue R = getTruncatedUSUBSAT(VT, LHS.getValueType(), LHS, RHS, + DAG, DL)) + return R; + } + if (Other && Other.getNumOperands() == 2) { SDValue CondRHS = RHS; SDValue OpLHS = Other.getOperand(0), OpRHS = Other.getOperand(1); - if (Other.getOpcode() == ISD::SUB && - LHS.getOpcode() == ISD::ZERO_EXTEND && LHS.getOperand(0) == OpLHS && - OpRHS.getOpcode() == ISD::TRUNCATE && OpRHS.getOperand(0) == RHS) { - // Look for a general sub with unsigned saturation first. - // zext(x) >= y ? x - trunc(y) : 0 - // --> usubsat(x,trunc(umin(y,SatLimit))) - // zext(x) > y ? x - trunc(y) : 0 - // --> usubsat(x,trunc(umin(y,SatLimit))) - if (SatCC == ISD::SETUGE || SatCC == ISD::SETUGT) - return getTruncatedUSUBSAT(VT, LHS.getValueType(), LHS, RHS, DAG, - DL); - } - if (OpLHS == LHS) { // Look for a general sub with unsigned saturation first. // x >= y ? x-y : 0 --> usubsat x, y @@ -10560,8 +10978,8 @@ SDValue DAGCombiner::visitVSELECT(SDNode *N) { // Another special case: If C was a sign bit, the sub has been // canonicalized into a xor. - // FIXME: Would it be better to use computeKnownBits to determine - // whether it's safe to decanonicalize the xor? + // FIXME: Would it be better to use computeKnownBits to + // determine whether it's safe to decanonicalize the xor? // x s< 0 ? x^C : 0 --> usubsat x, C APInt SplatValue; if (SatCC == ISD::SETLT && Other.getOpcode() == ISD::XOR && @@ -10627,17 +11045,18 @@ SDValue DAGCombiner::visitSELECT_CC(SDNode *N) { CC, SDLoc(N), false)) { AddToWorklist(SCC.getNode()); - if (ConstantSDNode *SCCC = dyn_cast<ConstantSDNode>(SCC.getNode())) { - if (!SCCC->isZero()) - return N2; // cond always true -> true val - else - return N3; // cond always false -> false val - } else if (SCC->isUndef()) { - // When the condition is UNDEF, just return the first operand. This is - // coherent the DAG creation, no setcc node is created in this case + // cond always true -> true val + // cond always false -> false val + if (auto *SCCC = dyn_cast<ConstantSDNode>(SCC.getNode())) + return SCCC->isZero() ? N3 : N2; + + // When the condition is UNDEF, just return the first operand. This is + // coherent the DAG creation, no setcc node is created in this case + if (SCC->isUndef()) return N2; - } else if (SCC.getOpcode() == ISD::SETCC) { - // Fold to a simpler select_cc + + // Fold to a simpler select_cc + if (SCC.getOpcode() == ISD::SETCC) { SDValue SelectOp = DAG.getNode( ISD::SELECT_CC, SDLoc(N), N2.getValueType(), SCC.getOperand(0), SCC.getOperand(1), N2, N3, SCC.getOperand(2)); @@ -10920,9 +11339,8 @@ static bool ExtendUsesToFormExtLoad(EVT VT, SDNode *N, SDValue N0, const TargetLowering &TLI) { bool HasCopyToRegUses = false; bool isTruncFree = TLI.isTruncateFree(VT, N0.getValueType()); - for (SDNode::use_iterator UI = N0.getNode()->use_begin(), - UE = N0.getNode()->use_end(); - UI != UE; ++UI) { + for (SDNode::use_iterator UI = N0->use_begin(), UE = N0->use_end(); UI != UE; + ++UI) { SDNode *User = *UI; if (User == N) continue; @@ -11254,9 +11672,12 @@ static SDValue tryToFoldExtOfLoad(SelectionDAG &DAG, DAGCombiner &Combiner, bool LegalOperations, SDNode *N, SDValue N0, ISD::LoadExtType ExtLoadType, ISD::NodeType ExtOpc) { + // TODO: isFixedLengthVector() should be removed and any negative effects on + // code generation being the result of that target's implementation of + // isVectorLoadExtDesirable(). if (!ISD::isNON_EXTLoad(N0.getNode()) || !ISD::isUNINDEXEDLoad(N0.getNode()) || - ((LegalOperations || VT.isVector() || + ((LegalOperations || VT.isFixedLengthVector() || !cast<LoadSDNode>(N0)->isSimple()) && !TLI.isLoadExtLegal(ExtLoadType, VT, N0.getValueType()))) return {}; @@ -11480,6 +11901,10 @@ SDValue DAGCombiner::visitSIGN_EXTEND(SDNode *N) { EVT VT = N->getValueType(0); SDLoc DL(N); + // sext(undef) = 0 because the top bit will all be the same. + if (N0.isUndef()) + return DAG.getConstant(0, DL, VT); + if (SDValue Res = tryToFoldExtendOfConstant(N, TLI, DAG, LegalTypes)) return Res; @@ -11649,10 +12074,10 @@ SDValue DAGCombiner::visitSIGN_EXTEND(SDNode *N) { // Return SDValue here as the xor should have already been replaced in // this sext. return SDValue(); - } else { - // Return a new sext with the new xor. - return DAG.getNode(ISD::SIGN_EXTEND, DL, VT, NewXor); } + + // Return a new sext with the new xor. + return DAG.getNode(ISD::SIGN_EXTEND, DL, VT, NewXor); } SDValue Zext = DAG.getNode(ISD::ZERO_EXTEND, DL, VT, N0.getOperand(0)); @@ -11725,6 +12150,10 @@ SDValue DAGCombiner::visitZERO_EXTEND(SDNode *N) { SDValue N0 = N->getOperand(0); EVT VT = N->getValueType(0); + // zext(undef) = 0 + if (N0.isUndef()) + return DAG.getConstant(0, SDLoc(N), VT); + if (SDValue Res = tryToFoldExtendOfConstant(N, TLI, DAG, LegalTypes)) return Res; @@ -11984,6 +12413,10 @@ SDValue DAGCombiner::visitANY_EXTEND(SDNode *N) { SDValue N0 = N->getOperand(0); EVT VT = N->getValueType(0); + // aext(undef) = undef + if (N0.isUndef()) + return DAG.getUNDEF(VT); + if (SDValue Res = tryToFoldExtendOfConstant(N, TLI, DAG, LegalTypes)) return Res; @@ -12021,11 +12454,10 @@ SDValue DAGCombiner::visitANY_EXTEND(SDNode *N) { !TLI.isTruncateFree(N0.getOperand(0).getOperand(0).getValueType(), N0.getValueType())) { SDLoc DL(N); - SDValue X = N0.getOperand(0).getOperand(0); - X = DAG.getAnyExtOrTrunc(X, DL, VT); - APInt Mask = N0.getConstantOperandAPInt(1).zext(VT.getSizeInBits()); - return DAG.getNode(ISD::AND, DL, VT, - X, DAG.getConstant(Mask, DL, VT)); + SDValue X = DAG.getAnyExtOrTrunc(N0.getOperand(0).getOperand(0), DL, VT); + SDValue Y = DAG.getNode(ISD::ANY_EXTEND, DL, VT, N0.getOperand(1)); + assert(isa<ConstantSDNode>(Y) && "Expected constant to be folded!"); + return DAG.getNode(ISD::AND, DL, VT, X, Y); } // fold (aext (load x)) -> (aext (truncate (extload x))) @@ -12153,13 +12585,9 @@ SDValue DAGCombiner::visitAssertExt(SDNode *N) { // This eliminates the later assert: // assert (trunc (assert X, i8) to iN), i1 --> trunc (assert X, i1) to iN // assert (trunc (assert X, i1) to iN), i8 --> trunc (assert X, i1) to iN + SDLoc DL(N); SDValue BigA = N0.getOperand(0); EVT BigA_AssertVT = cast<VTSDNode>(BigA.getOperand(1))->getVT(); - assert(BigA_AssertVT.bitsLE(N0.getValueType()) && - "Asserting zero/sign-extended bits to a type larger than the " - "truncated destination does not provide information"); - - SDLoc DL(N); EVT MinAssertVT = AssertVT.bitsLT(BigA_AssertVT) ? AssertVT : BigA_AssertVT; SDValue MinAssertVTVal = DAG.getValueType(MinAssertVT); SDValue NewAssert = DAG.getNode(Opcode, DL, BigA.getValueType(), @@ -12175,10 +12603,6 @@ SDValue DAGCombiner::visitAssertExt(SDNode *N) { Opcode == ISD::AssertZext) { SDValue BigA = N0.getOperand(0); EVT BigA_AssertVT = cast<VTSDNode>(BigA.getOperand(1))->getVT(); - assert(BigA_AssertVT.bitsLE(N0.getValueType()) && - "Asserting zero/sign-extended bits to a type larger than the " - "truncated destination does not provide information"); - if (AssertVT.bitsLT(BigA_AssertVT)) { SDLoc DL(N); SDValue NewAssert = DAG.getNode(Opcode, DL, BigA.getValueType(), @@ -12296,13 +12720,11 @@ SDValue DAGCombiner::reduceLoadWidth(SDNode *N) { unsigned ActiveBits = 0; if (Mask.isMask()) { ActiveBits = Mask.countTrailingOnes(); - } else if (Mask.isShiftedMask()) { - ShAmt = Mask.countTrailingZeros(); - APInt ShiftedMask = Mask.lshr(ShAmt); - ActiveBits = ShiftedMask.countTrailingOnes(); + } else if (Mask.isShiftedMask(ShAmt, ActiveBits)) { HasShiftedOffset = true; - } else + } else { return SDValue(); + } ExtType = ISD::ZEXTLOAD; ExtVT = EVT::getIntegerVT(*DAG.getContext(), ActiveBits); @@ -12919,21 +13341,6 @@ SDValue DAGCombiner::visitTRUNCATE(SDNode *N) { if (SimplifyDemandedBits(SDValue(N, 0))) return SDValue(N, 0); - // (trunc adde(X, Y, Carry)) -> (adde trunc(X), trunc(Y), Carry) - // (trunc addcarry(X, Y, Carry)) -> (addcarry trunc(X), trunc(Y), Carry) - // When the adde's carry is not used. - if ((N0.getOpcode() == ISD::ADDE || N0.getOpcode() == ISD::ADDCARRY) && - N0.hasOneUse() && !N0.getNode()->hasAnyUseOfValue(1) && - // We only do for addcarry before legalize operation - ((!LegalOperations && N0.getOpcode() == ISD::ADDCARRY) || - TLI.isOperationLegal(N0.getOpcode(), VT))) { - SDLoc SL(N); - auto X = DAG.getNode(ISD::TRUNCATE, SL, VT, N0.getOperand(0)); - auto Y = DAG.getNode(ISD::TRUNCATE, SL, VT, N0.getOperand(1)); - auto VTs = DAG.getVTList(VT, N0->getValueType(1)); - return DAG.getNode(N0.getOpcode(), SL, VTs, X, Y, N0.getOperand(2)); - } - // fold (truncate (extract_subvector(ext x))) -> // (extract_subvector x) // TODO: This can be generalized to cover cases where the truncate and extract @@ -12978,6 +13385,22 @@ SDValue DAGCombiner::visitTRUNCATE(SDNode *N) { } } break; + case ISD::ADDE: + case ISD::ADDCARRY: + // (trunc adde(X, Y, Carry)) -> (adde trunc(X), trunc(Y), Carry) + // (trunc addcarry(X, Y, Carry)) -> (addcarry trunc(X), trunc(Y), Carry) + // When the adde's carry is not used. + // We only do for addcarry before legalize operation + if (((!LegalOperations && N0.getOpcode() == ISD::ADDCARRY) || + TLI.isOperationLegal(N0.getOpcode(), VT)) && + N0.hasOneUse() && !N0->hasAnyUseOfValue(1)) { + SDLoc DL(N); + SDValue X = DAG.getNode(ISD::TRUNCATE, DL, VT, N0.getOperand(0)); + SDValue Y = DAG.getNode(ISD::TRUNCATE, DL, VT, N0.getOperand(1)); + SDVTList VTs = DAG.getVTList(VT, N0->getValueType(1)); + return DAG.getNode(N0.getOpcode(), DL, VTs, X, Y, N0.getOperand(2)); + } + break; case ISD::USUBSAT: // Truncate the USUBSAT only if LHS is a known zero-extension, its not // enough to know that the upper bits are zero we must ensure that we don't @@ -13111,7 +13534,7 @@ SDValue DAGCombiner::visitBITCAST(SDNode *N) { (!LegalTypes || (!LegalOperations && VT.isInteger() && N0.getValueType().isInteger() && TLI.isTypeLegal(VT.getVectorElementType()))) && - N0.getOpcode() == ISD::BUILD_VECTOR && N0.getNode()->hasOneUse() && + N0.getOpcode() == ISD::BUILD_VECTOR && N0->hasOneUse() && cast<BuildVectorSDNode>(N0)->isConstant()) return ConstantFoldBITCASTofBUILD_VECTOR(N0.getNode(), VT.getVectorElementType()); @@ -13179,8 +13602,8 @@ SDValue DAGCombiner::visitBITCAST(SDNode *N) { // This often reduces constant pool loads. if (((N0.getOpcode() == ISD::FNEG && !TLI.isFNegFree(N0.getValueType())) || (N0.getOpcode() == ISD::FABS && !TLI.isFAbsFree(N0.getValueType()))) && - N0.getNode()->hasOneUse() && VT.isInteger() && - !VT.isVector() && !N0.getValueType().isVector()) { + N0->hasOneUse() && VT.isInteger() && !VT.isVector() && + !N0.getValueType().isVector()) { SDValue NewConv = DAG.getBitcast(VT, N0.getOperand(0)); AddToWorklist(NewConv.getNode()); @@ -13228,9 +13651,9 @@ SDValue DAGCombiner::visitBITCAST(SDNode *N) { // (xor (bitcast cst), (bitcast x)), 0), // signbit) // (xor (bitcast cst) (build_pair flipbit, flipbit)) - if (N0.getOpcode() == ISD::FCOPYSIGN && N0.getNode()->hasOneUse() && - isa<ConstantFPSDNode>(N0.getOperand(0)) && - VT.isInteger() && !VT.isVector()) { + if (N0.getOpcode() == ISD::FCOPYSIGN && N0->hasOneUse() && + isa<ConstantFPSDNode>(N0.getOperand(0)) && VT.isInteger() && + !VT.isVector()) { unsigned OrigXWidth = N0.getOperand(1).getValueSizeInBits(); EVT IntXVT = EVT::getIntegerVT(*DAG.getContext(), OrigXWidth); if (isTypeLegal(IntXVT)) { @@ -13312,8 +13735,7 @@ SDValue DAGCombiner::visitBITCAST(SDNode *N) { if (Op.getOpcode() == ISD::BITCAST && Op.getOperand(0).getValueType() == VT) return SDValue(Op.getOperand(0)); - if (Op.isUndef() || ISD::isBuildVectorOfConstantSDNodes(Op.getNode()) || - ISD::isBuildVectorOfConstantFPSDNodes(Op.getNode())) + if (Op.isUndef() || isAnyConstantBuildVector(Op)) return DAG.getBitcast(VT, Op); return SDValue(); }; @@ -13353,6 +13775,14 @@ SDValue DAGCombiner::visitFREEZE(SDNode *N) { if (DAG.isGuaranteedNotToBeUndefOrPoison(N0, /*PoisonOnly*/ false)) return N0; + // Fold freeze(bitcast(x)) -> bitcast(freeze(x)). + // TODO: Replace with pushFreezeToPreventPoisonFromPropagating fold. + if (N0.getOpcode() == ISD::BITCAST) + return DAG.getBitcast(N->getValueType(0), + DAG.getNode(ISD::FREEZE, SDLoc(N0), + N0.getOperand(0).getValueType(), + N0.getOperand(0))); + return SDValue(); } @@ -13444,7 +13874,7 @@ static bool isContractableFMUL(const TargetOptions &Options, SDValue N) { // Returns true if `N` can assume no infinities involved in its computation. static bool hasNoInfs(const TargetOptions &Options, SDValue N) { - return Options.NoInfsFPMath || N.getNode()->getFlags().hasNoInfs(); + return Options.NoInfsFPMath || N->getFlags().hasNoInfs(); } /// Try to perform FMA combining on a given FADD node. @@ -13498,7 +13928,7 @@ SDValue DAGCombiner::visitFADDForFMACombine(SDNode *N) { // If we have two choices trying to fold (fadd (fmul u, v), (fmul x, y)), // prefer to fold the multiply with fewer uses. if (Aggressive && isContractableFMUL(N0) && isContractableFMUL(N1)) { - if (N0.getNode()->use_size() > N1.getNode()->use_size()) + if (N0->use_size() > N1->use_size()) std::swap(N0, N1); } @@ -13728,7 +14158,7 @@ SDValue DAGCombiner::visitFSUBForFMACombine(SDNode *N) { // If we have two choices trying to fold (fsub (fmul u, v), (fmul x, y)), // prefer to fold the multiply with fewer uses. if (isContractableFMUL(N0) && isContractableFMUL(N1) && - (N0.getNode()->use_size() > N1.getNode()->use_size())) { + (N0->use_size() > N1->use_size())) { // fold (fsub (fmul a, b), (fmul c, d)) -> (fma (fneg c), d, (fmul a, b)) if (SDValue V = tryToFoldXSubYZ(N0, N1)) return V; @@ -14851,7 +15281,7 @@ SDValue DAGCombiner::visitFREM(SDNode *N) { // fold (frem c1, c2) -> fmod(c1,c2) if (SDValue C = DAG.FoldConstantArithmetic(ISD::FREM, SDLoc(N), VT, {N0, N1})) return C; - + if (SDValue NewSel = foldBinOpIntoSelect(N)) return NewSel; @@ -15174,7 +15604,7 @@ static SDValue FoldIntToFPToInt(SDNode *N, SelectionDAG &DAG) { // This means this is also safe for a signed input and unsigned output, since // a negative input would lead to undefined behavior. unsigned InputSize = (int)SrcVT.getScalarSizeInBits() - IsInputSigned; - unsigned OutputSize = (int)VT.getScalarSizeInBits() - IsOutputSigned; + unsigned OutputSize = (int)VT.getScalarSizeInBits(); unsigned ActualSize = std::min(InputSize, OutputSize); const fltSemantics &sem = DAG.EVTToAPFloatSemantics(N0.getValueType()); @@ -15265,7 +15695,7 @@ SDValue DAGCombiner::visitFP_ROUND(SDNode *N) { } // fold (fp_round (copysign X, Y)) -> (copysign (fp_round X), Y) - if (N0.getOpcode() == ISD::FCOPYSIGN && N0.getNode()->hasOneUse()) { + if (N0.getOpcode() == ISD::FCOPYSIGN && N0->hasOneUse()) { SDValue Tmp = DAG.getNode(ISD::FP_ROUND, SDLoc(N0), VT, N0.getOperand(0), N1); AddToWorklist(Tmp.getNode()); @@ -15709,7 +16139,7 @@ bool DAGCombiner::CombineToPreIndexedLoadStore(SDNode *N) { // If the pointer is not an add/sub, or if it doesn't have multiple uses, bail // out. There is no reason to make this a preinc/predec. if ((Ptr.getOpcode() != ISD::ADD && Ptr.getOpcode() != ISD::SUB) || - Ptr.getNode()->hasOneUse()) + Ptr->hasOneUse()) return false; // Ask the target to do addressing mode selection. @@ -15769,8 +16199,8 @@ bool DAGCombiner::CombineToPreIndexedLoadStore(SDNode *N) { // a copy of the original base pointer. SmallVector<SDNode *, 16> OtherUses; if (isa<ConstantSDNode>(Offset)) - for (SDNode::use_iterator UI = BasePtr.getNode()->use_begin(), - UE = BasePtr.getNode()->use_end(); + for (SDNode::use_iterator UI = BasePtr->use_begin(), + UE = BasePtr->use_end(); UI != UE; ++UI) { SDUse &Use = UI.getUse(); // Skip the use that is Ptr and uses of other results from BasePtr's @@ -15808,7 +16238,7 @@ bool DAGCombiner::CombineToPreIndexedLoadStore(SDNode *N) { // Now check for #3 and #4. bool RealUse = false; - for (SDNode *Use : Ptr.getNode()->uses()) { + for (SDNode *Use : Ptr->uses()) { if (Use == N) continue; if (SDNode::hasPredecessorHelper(Use, Visited, Worklist)) @@ -15841,7 +16271,7 @@ bool DAGCombiner::CombineToPreIndexedLoadStore(SDNode *N) { ++PreIndexedNodes; ++NodesCombined; LLVM_DEBUG(dbgs() << "\nReplacing.4 "; N->dump(&DAG); dbgs() << "\nWith: "; - Result.getNode()->dump(&DAG); dbgs() << '\n'); + Result.dump(&DAG); dbgs() << '\n'); WorklistRemover DeadNodes(*this); if (IsLoad) { DAG.ReplaceAllUsesOfValueWith(SDValue(N, 0), Result.getValue(0)); @@ -15931,7 +16361,7 @@ static bool shouldCombineToPostInc(SDNode *N, SDValue Ptr, SDNode *PtrUse, return false; SmallPtrSet<const SDNode *, 32> Visited; - for (SDNode *Use : BasePtr.getNode()->uses()) { + for (SDNode *Use : BasePtr->uses()) { if (Use == Ptr.getNode()) continue; @@ -15968,7 +16398,7 @@ static SDNode *getPostIndexedLoadStoreOp(SDNode *N, bool &IsLoad, const TargetLowering &TLI) { if (!getCombineLoadStoreParts(N, ISD::POST_INC, ISD::POST_DEC, IsLoad, IsMasked, Ptr, TLI) || - Ptr.getNode()->hasOneUse()) + Ptr->hasOneUse()) return nullptr; // Try turning it into a post-indexed load / store except when @@ -16028,9 +16458,8 @@ bool DAGCombiner::CombineToPostIndexedLoadStore(SDNode *N) { BasePtr, Offset, AM); ++PostIndexedNodes; ++NodesCombined; - LLVM_DEBUG(dbgs() << "\nReplacing.5 "; N->dump(&DAG); - dbgs() << "\nWith: "; Result.getNode()->dump(&DAG); - dbgs() << '\n'); + LLVM_DEBUG(dbgs() << "\nReplacing.5 "; N->dump(&DAG); dbgs() << "\nWith: "; + Result.dump(&DAG); dbgs() << '\n'); WorklistRemover DeadNodes(*this); if (IsLoad) { DAG.ReplaceAllUsesOfValueWith(SDValue(N, 0), Result.getValue(0)); @@ -16271,7 +16700,7 @@ SDValue DAGCombiner::visitLOAD(SDNode *N) { // Now we replace use of chain2 with chain1. This makes the second load // isomorphic to the one we are deleting, and thus makes this load live. LLVM_DEBUG(dbgs() << "\nReplacing.6 "; N->dump(&DAG); - dbgs() << "\nWith chain: "; Chain.getNode()->dump(&DAG); + dbgs() << "\nWith chain: "; Chain.dump(&DAG); dbgs() << "\n"); WorklistRemover DeadNodes(*this); DAG.ReplaceAllUsesOfValueWith(SDValue(N, 1), Chain); @@ -16302,7 +16731,7 @@ SDValue DAGCombiner::visitLOAD(SDNode *N) { } else Index = DAG.getUNDEF(N->getValueType(1)); LLVM_DEBUG(dbgs() << "\nReplacing.7 "; N->dump(&DAG); - dbgs() << "\nWith: "; Undef.getNode()->dump(&DAG); + dbgs() << "\nWith: "; Undef.dump(&DAG); dbgs() << " and 2 other values\n"); WorklistRemover DeadNodes(*this); DAG.ReplaceAllUsesOfValueWith(SDValue(N, 0), Undef); @@ -17014,11 +17443,19 @@ ShrinkLoadReplaceStoreWithStore(const std::pair<unsigned, unsigned> &MaskInfo, // Check that it is legal on the target to do this. It is legal if the new // VT we're shrinking to (i8/i16/i32) is legal or we're still before type - // legalization (and the target doesn't explicitly think this is a bad idea). + // legalization. If the source type is legal, but the store type isn't, see + // if we can use a truncating store. MVT VT = MVT::getIntegerVT(NumBytes * 8); const TargetLowering &TLI = DAG.getTargetLoweringInfo(); - if (!DC->isTypeLegal(VT)) + bool UseTruncStore; + if (DC->isTypeLegal(VT)) + UseTruncStore = false; + else if (TLI.isTypeLegal(IVal.getValueType()) && + TLI.isTruncStoreLegal(IVal.getValueType(), VT)) + UseTruncStore = true; + else return SDValue(); + // Check that the target doesn't think this is a bad idea. if (St->getMemOperand() && !TLI.allowsMemoryAccess(*DAG.getContext(), DAG.getDataLayout(), VT, *St->getMemOperand())) @@ -17046,10 +17483,15 @@ ShrinkLoadReplaceStoreWithStore(const std::pair<unsigned, unsigned> &MaskInfo, Ptr = DAG.getMemBasePlusOffset(Ptr, TypeSize::Fixed(StOffset), DL); } + ++OpsNarrowed; + if (UseTruncStore) + return DAG.getTruncStore(St->getChain(), SDLoc(St), IVal, Ptr, + St->getPointerInfo().getWithOffset(StOffset), + VT, St->getOriginalAlign()); + // Truncate down to the new size. IVal = DAG.getNode(ISD::TRUNCATE, SDLoc(IVal), VT, IVal); - ++OpsNarrowed; return DAG .getStore(St->getChain(), SDLoc(St), IVal, Ptr, St->getPointerInfo().getWithOffset(StOffset), @@ -17070,11 +17512,15 @@ SDValue DAGCombiner::ReduceLoadOpStoreWidth(SDNode *N) { SDValue Ptr = ST->getBasePtr(); EVT VT = Value.getValueType(); - if (ST->isTruncatingStore() || VT.isVector() || !Value.hasOneUse()) + if (ST->isTruncatingStore() || VT.isVector()) return SDValue(); unsigned Opc = Value.getOpcode(); + if ((Opc != ISD::OR && Opc != ISD::XOR && Opc != ISD::AND) || + !Value.hasOneUse()) + return SDValue(); + // If this is "store (or X, Y), P" and X is "(and (load P), cst)", where cst // is a byte mask indicating a consecutive number of bytes, check to see if // Y is known to provide just those bytes. If so, we try to replace the @@ -17099,8 +17545,7 @@ SDValue DAGCombiner::ReduceLoadOpStoreWidth(SDNode *N) { if (!EnableReduceLoadOpStoreWidth) return SDValue(); - if ((Opc != ISD::OR && Opc != ISD::XOR && Opc != ISD::AND) || - Value.getOperand(1).getOpcode() != ISD::Constant) + if (Value.getOperand(1).getOpcode() != ISD::Constant) return SDValue(); SDValue N0 = Value.getOperand(0); @@ -17256,14 +17701,13 @@ SDValue DAGCombiner::TransformFPLoadStorePair(SDNode *N) { // (A + c1) * c3 // (A + c2) * c3 // We're checking for cases where we have common "c3 * A" expressions. -bool DAGCombiner::isMulAddWithConstProfitable(SDNode *MulNode, - SDValue &AddNode, - SDValue &ConstNode) { +bool DAGCombiner::isMulAddWithConstProfitable(SDNode *MulNode, SDValue AddNode, + SDValue ConstNode) { APInt Val; // If the add only has one use, and the target thinks the folding is // profitable or does not lead to worse code, this would be OK to do. - if (AddNode.getNode()->hasOneUse() && + if (AddNode->hasOneUse() && TLI.isMulAddWithConstProfitable(AddNode, ConstNode)) return true; @@ -17397,7 +17841,9 @@ bool DAGCombiner::mergeStoresOfConstantsOrVecElts( if (isa<ConstantFPSDNode>(Val)) { // Not clear how to truncate FP values. return false; - } else if (auto *C = dyn_cast<ConstantSDNode>(Val)) + } + + if (auto *C = dyn_cast<ConstantSDNode>(Val)) Val = DAG.getConstant(C->getAPIntValue() .zextOrTrunc(Val.getValueSizeInBits()) .zextOrTrunc(ElementSizeBits), @@ -17491,7 +17937,7 @@ bool DAGCombiner::mergeStoresOfConstantsOrVecElts( if (!UseTrunc) { NewStore = DAG.getStore(NewChain, DL, StoredVal, FirstInChain->getBasePtr(), FirstInChain->getPointerInfo(), - FirstInChain->getAlign(), Flags.getValue(), AAInfo); + FirstInChain->getAlign(), *Flags, AAInfo); } else { // Must be realized as a trunc store EVT LegalizedStoredValTy = TLI.getTypeToTransformTo(*DAG.getContext(), StoredVal.getValueType()); @@ -17503,7 +17949,7 @@ bool DAGCombiner::mergeStoresOfConstantsOrVecElts( NewStore = DAG.getTruncStore( NewChain, DL, ExtendedStoreVal, FirstInChain->getBasePtr(), FirstInChain->getPointerInfo(), StoredVal.getValueType() /*TVT*/, - FirstInChain->getAlign(), Flags.getValue(), AAInfo); + FirstInChain->getAlign(), *Flags, AAInfo); } // Replace all merged stores with the new store. @@ -17671,11 +18117,9 @@ void DAGCombiner::getStoreMergeCandidates( } } -// We need to check that merging these stores does not cause a loop in -// the DAG. Any store candidate may depend on another candidate -// indirectly through its operand (we already consider dependencies -// through the chain). Check in parallel by searching up from -// non-chain operands of candidates. +// We need to check that merging these stores does not cause a loop in the +// DAG. Any store candidate may depend on another candidate indirectly through +// its operands. Check in parallel by searching up from operands of candidates. bool DAGCombiner::checkMergeStoreCandidatesForDependencies( SmallVectorImpl<MemOpLink> &StoreNodes, unsigned NumStores, SDNode *RootNode) { @@ -17709,8 +18153,13 @@ bool DAGCombiner::checkMergeStoreCandidatesForDependencies( SDNode *N = StoreNodes[i].MemNode; // Of the 4 Store Operands: // * Chain (Op 0) -> We have already considered these - // in candidate selection and can be - // safely ignored + // in candidate selection, but only by following the + // chain dependencies. We could still have a chain + // dependency to a load, that has a non-chain dep to + // another load, that depends on a store, etc. So it is + // possible to have dependencies that consist of a mix + // of chain and non-chain deps, and we need to include + // chain operands in the analysis here.. // * Value (Op 1) -> Cycles may happen (e.g. through load chains) // * Address (Op 2) -> Merged addresses may only vary by a fixed constant, // but aren't necessarily fromt the same base node, so @@ -17718,7 +18167,7 @@ bool DAGCombiner::checkMergeStoreCandidatesForDependencies( // * (Op 3) -> Represents the pre or post-indexing offset (or undef for // non-indexed stores). Not constant on all targets (e.g. ARM) // and so can participate in a cycle. - for (unsigned j = 1; j < N->getNumOperands(); ++j) + for (unsigned j = 0; j < N->getNumOperands(); ++j) Worklist.push_back(N->getOperand(j).getNode()); } // Search through DAG. We can stop early if we find a store node. @@ -17793,7 +18242,7 @@ bool DAGCombiner::tryStoreMergeOfConstants( while (NumConsecutiveStores >= 2) { LSBaseSDNode *FirstInChain = StoreNodes[0].MemNode; unsigned FirstStoreAS = FirstInChain->getAddressSpace(); - unsigned FirstStoreAlign = FirstInChain->getAlignment(); + Align FirstStoreAlign = FirstInChain->getAlign(); unsigned LastLegalType = 1; unsigned LastLegalVectorType = 1; bool LastIntegerTrunc = false; @@ -17881,7 +18330,7 @@ bool DAGCombiner::tryStoreMergeOfConstants( unsigned NumSkip = 1; while ((NumSkip < NumConsecutiveStores) && (NumSkip < FirstZeroAfterNonZero) && - (StoreNodes[NumSkip].MemNode->getAlignment() <= FirstStoreAlign)) + (StoreNodes[NumSkip].MemNode->getAlign() <= FirstStoreAlign)) NumSkip++; StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + NumSkip); @@ -17920,7 +18369,7 @@ bool DAGCombiner::tryStoreMergeOfExtracts( while (NumConsecutiveStores >= 2) { LSBaseSDNode *FirstInChain = StoreNodes[0].MemNode; unsigned FirstStoreAS = FirstInChain->getAddressSpace(); - unsigned FirstStoreAlign = FirstInChain->getAlignment(); + Align FirstStoreAlign = FirstInChain->getAlign(); unsigned NumStoresToMerge = 1; for (unsigned i = 0; i < NumConsecutiveStores; ++i) { // Find a legal type for the vector store. @@ -17951,7 +18400,7 @@ bool DAGCombiner::tryStoreMergeOfExtracts( // improved. Drop as many candidates as we can here. unsigned NumSkip = 1; while ((NumSkip < NumConsecutiveStores) && - (StoreNodes[NumSkip].MemNode->getAlignment() <= FirstStoreAlign)) + (StoreNodes[NumSkip].MemNode->getAlign() <= FirstStoreAlign)) NumSkip++; StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + NumSkip); @@ -18248,7 +18697,7 @@ bool DAGCombiner::tryStoreMergeOfLoads(SmallVectorImpl<MemOpLink> &StoreNodes, for (unsigned i = 0; i < NumElem; ++i) { SDValue Val = StoreNodes[i].MemNode->getOperand(1); CombineTo(StoreNodes[i].MemNode, NewStore); - if (Val.getNode()->use_empty()) + if (Val->use_empty()) recursivelyDeleteUnusedNodes(Val.getNode()); } @@ -18398,6 +18847,7 @@ SDValue DAGCombiner::replaceStoreOfFPConstant(StoreSDNode *ST) { default: llvm_unreachable("Unknown FP type"); case MVT::f16: // We don't do this for these yet. + case MVT::bf16: case MVT::f80: case MVT::f128: case MVT::ppcf128: @@ -18405,7 +18855,6 @@ SDValue DAGCombiner::replaceStoreOfFPConstant(StoreSDNode *ST) { case MVT::f32: if ((isTypeLegal(MVT::i32) && !LegalOperations && ST->isSimple()) || TLI.isOperationLegalOrCustom(ISD::STORE, MVT::i32)) { - ; Tmp = DAG.getConstant((uint32_t)CFP->getValueAPF(). bitcastToAPInt().getZExtValue(), SDLoc(CFP), MVT::i32); @@ -18417,7 +18866,6 @@ SDValue DAGCombiner::replaceStoreOfFPConstant(StoreSDNode *ST) { if ((TLI.isTypeLegal(MVT::i64) && !LegalOperations && ST->isSimple()) || TLI.isOperationLegalOrCustom(ISD::STORE, MVT::i64)) { - ; Tmp = DAG.getConstant(CFP->getValueAPF().bitcastToAPInt(). getZExtValue(), SDLoc(CFP), MVT::i64); return DAG.getStore(Chain, DL, Tmp, @@ -18611,7 +19059,7 @@ SDValue DAGCombiner::visitSTORE(SDNode *N) { // truncating store. We can do this even if this is already a truncstore. if ((Value.getOpcode() == ISD::FP_ROUND || Value.getOpcode() == ISD::TRUNCATE) && - Value.getNode()->hasOneUse() && ST->isUnindexed() && + Value->hasOneUse() && ST->isUnindexed() && TLI.canCombineTruncStore(Value.getOperand(0).getValueType(), ST->getMemoryVT(), LegalOperations)) { return DAG.getTruncStore(Chain, SDLoc(N), Value.getOperand(0), @@ -18874,6 +19322,14 @@ SDValue DAGCombiner::combineInsertEltToShuffle(SDNode *N, unsigned InsIndex) { } } + // If we failed to find a match, see if we can replace an UNDEF shuffle + // operand. + if (ElementOffset == -1 && Y.isUndef() && + InsertVal0.getValueType() == Y.getValueType()) { + ElementOffset = Mask.size(); + Y = InsertVal0; + } + if (ElementOffset != -1) { SmallVector<int, 16> NewMask(Mask.begin(), Mask.end()); @@ -18972,10 +19428,9 @@ SDValue DAGCombiner::visitINSERT_VECTOR_ELT(SDNode *N) { if (InVec.isUndef() && TLI.shouldSplatInsEltVarIndex(VT)) { if (VT.isScalableVector()) return DAG.getSplatVector(VT, DL, InVal); - else { - SmallVector<SDValue, 8> Ops(VT.getVectorNumElements(), InVal); - return DAG.getBuildVector(VT, DL, Ops); - } + + SmallVector<SDValue, 8> Ops(VT.getVectorNumElements(), InVal); + return DAG.getBuildVector(VT, DL, Ops); } return SDValue(); } @@ -18987,9 +19442,19 @@ SDValue DAGCombiner::visitINSERT_VECTOR_ELT(SDNode *N) { // We must know which element is being inserted for folds below here. unsigned Elt = IndexC->getZExtValue(); + if (SDValue Shuf = combineInsertEltToShuffle(N, Elt)) return Shuf; + // Handle <1 x ???> vector insertion special cases. + if (VT.getVectorNumElements() == 1) { + // insert_vector_elt(x, extract_vector_elt(y, 0), 0) -> y + if (InVal.getOpcode() == ISD::EXTRACT_VECTOR_ELT && + InVal.getOperand(0).getValueType() == VT && + isNullConstant(InVal.getOperand(1))) + return InVal.getOperand(0); + } + // Canonicalize insert_vector_elt dag nodes. // Example: // (insert_vector_elt (insert_vector_elt A, Idx0), Idx1) @@ -19010,36 +19475,84 @@ SDValue DAGCombiner::visitINSERT_VECTOR_ELT(SDNode *N) { } } - // If we can't generate a legal BUILD_VECTOR, exit - if (LegalOperations && !TLI.isOperationLegal(ISD::BUILD_VECTOR, VT)) - return SDValue(); + // Attempt to fold the insertion into a legal BUILD_VECTOR. + if (!LegalOperations || TLI.isOperationLegal(ISD::BUILD_VECTOR, VT)) { + auto UpdateBuildVector = [&](SmallVectorImpl<SDValue> &Ops) { + assert(Ops.size() == NumElts && "Unexpected vector size"); - // Check that the operand is a BUILD_VECTOR (or UNDEF, which can essentially - // be converted to a BUILD_VECTOR). Fill in the Ops vector with the - // vector elements. - SmallVector<SDValue, 8> Ops; - // Do not combine these two vectors if the output vector will not replace - // the input vector. - if (InVec.getOpcode() == ISD::BUILD_VECTOR && InVec.hasOneUse()) { - Ops.append(InVec.getNode()->op_begin(), - InVec.getNode()->op_end()); - } else if (InVec.isUndef()) { - Ops.append(NumElts, DAG.getUNDEF(InVal.getValueType())); - } else { - return SDValue(); - } - assert(Ops.size() == NumElts && "Unexpected vector size"); + // Insert the element + if (Elt < Ops.size()) { + // All the operands of BUILD_VECTOR must have the same type; + // we enforce that here. + EVT OpVT = Ops[0].getValueType(); + Ops[Elt] = + OpVT.isInteger() ? DAG.getAnyExtOrTrunc(InVal, DL, OpVT) : InVal; + } + + // Return the new vector + return DAG.getBuildVector(VT, DL, Ops); + }; + + // Check that the operand is a BUILD_VECTOR (or UNDEF, which can essentially + // be converted to a BUILD_VECTOR). Fill in the Ops vector with the + // vector elements. + SmallVector<SDValue, 8> Ops; + + // Do not combine these two vectors if the output vector will not replace + // the input vector. + if (InVec.getOpcode() == ISD::BUILD_VECTOR && InVec.hasOneUse()) { + Ops.append(InVec->op_begin(), InVec->op_end()); + return UpdateBuildVector(Ops); + } + + if (InVec.getOpcode() == ISD::SCALAR_TO_VECTOR && InVec.hasOneUse()) { + Ops.push_back(InVec.getOperand(0)); + Ops.append(NumElts - 1, DAG.getUNDEF(InVec.getOperand(0).getValueType())); + return UpdateBuildVector(Ops); + } + + if (InVec.isUndef()) { + Ops.append(NumElts, DAG.getUNDEF(InVal.getValueType())); + return UpdateBuildVector(Ops); + } + + // If we're inserting into the end of a vector as part of an sequence, see + // if we can create a BUILD_VECTOR by following the sequence back up the + // chain. + if (Elt == (NumElts - 1)) { + SmallVector<SDValue> ReverseInsertions; + ReverseInsertions.push_back(InVal); + + EVT MaxEltVT = InVal.getValueType(); + SDValue CurVec = InVec; + for (unsigned I = 1; I != NumElts; ++I) { + if (CurVec.getOpcode() != ISD::INSERT_VECTOR_ELT || !CurVec.hasOneUse()) + break; - // Insert the element - if (Elt < Ops.size()) { - // All the operands of BUILD_VECTOR must have the same type; - // we enforce that here. - EVT OpVT = Ops[0].getValueType(); - Ops[Elt] = OpVT.isInteger() ? DAG.getAnyExtOrTrunc(InVal, DL, OpVT) : InVal; + auto *CurIdx = dyn_cast<ConstantSDNode>(CurVec.getOperand(2)); + if (!CurIdx || CurIdx->getAPIntValue() != ((NumElts - 1) - I)) + break; + SDValue CurVal = CurVec.getOperand(1); + ReverseInsertions.push_back(CurVal); + if (VT.isInteger()) { + EVT CurValVT = CurVal.getValueType(); + MaxEltVT = MaxEltVT.bitsGE(CurValVT) ? MaxEltVT : CurValVT; + } + CurVec = CurVec.getOperand(0); + } + + if (ReverseInsertions.size() == NumElts) { + for (unsigned I = 0; I != NumElts; ++I) { + SDValue Val = ReverseInsertions[(NumElts - 1) - I]; + Val = VT.isInteger() ? DAG.getAnyExtOrTrunc(Val, DL, MaxEltVT) : Val; + Ops.push_back(Val); + } + return DAG.getBuildVector(VT, DL, Ops); + } + } } - // Return the new vector - return DAG.getBuildVector(VT, DL, Ops); + return SDValue(); } SDValue DAGCombiner::scalarizeExtractedVectorLoad(SDNode *EVE, EVT InVecVT, @@ -19088,47 +19601,33 @@ SDValue DAGCombiner::scalarizeExtractedVectorLoad(SDNode *EVE, EVT InVecVT, SDValue NewPtr = TLI.getVectorElementPointer(DAG, OriginalLoad->getBasePtr(), InVecVT, EltNo); - // The replacement we need to do here is a little tricky: we need to - // replace an extractelement of a load with a load. - // Use ReplaceAllUsesOfValuesWith to do the replacement. - // Note that this replacement assumes that the extractvalue is the only - // use of the load; that's okay because we don't want to perform this - // transformation in other cases anyway. + // We are replacing a vector load with a scalar load. The new load must have + // identical memory op ordering to the original. SDValue Load; - SDValue Chain; if (ResultVT.bitsGT(VecEltVT)) { // If the result type of vextract is wider than the load, then issue an // extending load instead. - ISD::LoadExtType ExtType = TLI.isLoadExtLegal(ISD::ZEXTLOAD, ResultVT, - VecEltVT) - ? ISD::ZEXTLOAD - : ISD::EXTLOAD; - Load = DAG.getExtLoad(ExtType, SDLoc(EVE), ResultVT, - OriginalLoad->getChain(), NewPtr, MPI, VecEltVT, - Alignment, OriginalLoad->getMemOperand()->getFlags(), + ISD::LoadExtType ExtType = + TLI.isLoadExtLegal(ISD::ZEXTLOAD, ResultVT, VecEltVT) ? ISD::ZEXTLOAD + : ISD::EXTLOAD; + Load = DAG.getExtLoad(ExtType, DL, ResultVT, OriginalLoad->getChain(), + NewPtr, MPI, VecEltVT, Alignment, + OriginalLoad->getMemOperand()->getFlags(), OriginalLoad->getAAInfo()); - Chain = Load.getValue(1); + DAG.makeEquivalentMemoryOrdering(OriginalLoad, Load); } else { - Load = DAG.getLoad( - VecEltVT, SDLoc(EVE), OriginalLoad->getChain(), NewPtr, MPI, Alignment, - OriginalLoad->getMemOperand()->getFlags(), OriginalLoad->getAAInfo()); - Chain = Load.getValue(1); + // The result type is narrower or the same width as the vector element + Load = DAG.getLoad(VecEltVT, DL, OriginalLoad->getChain(), NewPtr, MPI, + Alignment, OriginalLoad->getMemOperand()->getFlags(), + OriginalLoad->getAAInfo()); + DAG.makeEquivalentMemoryOrdering(OriginalLoad, Load); if (ResultVT.bitsLT(VecEltVT)) - Load = DAG.getNode(ISD::TRUNCATE, SDLoc(EVE), ResultVT, Load); + Load = DAG.getNode(ISD::TRUNCATE, DL, ResultVT, Load); else Load = DAG.getBitcast(ResultVT, Load); } - WorklistRemover DeadNodes(*this); - SDValue From[] = { SDValue(EVE, 0), SDValue(OriginalLoad, 1) }; - SDValue To[] = { Load, Chain }; - DAG.ReplaceAllUsesOfValuesWith(From, To, 2); - // Make sure to revisit this node to clean it up; it will usually be dead. - AddToWorklist(EVE); - // Since we're explicitly calling ReplaceAllUses, add the new node to the - // worklist explicitly as well. - AddToWorklistWithUsers(Load.getNode()); ++OpsNarrowed; - return SDValue(EVE, 0); + return Load; } /// Transform a vector binary operation into a scalar binary operation by moving @@ -19140,7 +19639,7 @@ static SDValue scalarizeExtractedBinop(SDNode *ExtElt, SelectionDAG &DAG, SDValue Index = ExtElt->getOperand(1); auto *IndexC = dyn_cast<ConstantSDNode>(Index); if (!IndexC || !TLI.isBinOp(Vec.getOpcode()) || !Vec.hasOneUse() || - Vec.getNode()->getNumValues() != 1) + Vec->getNumValues() != 1) return SDValue(); // Targets may want to avoid this to prevent an expensive register transfer. @@ -19196,8 +19695,9 @@ SDValue DAGCombiner::visitEXTRACT_VECTOR_ELT(SDNode *N) { // EXTRACT_VECTOR_ELT may widen the extracted vector. SDValue InOp = VecOp.getOperand(0); if (InOp.getValueType() != ScalarVT) { - assert(InOp.getValueType().isInteger() && ScalarVT.isInteger()); - return DAG.getSExtOrTrunc(InOp, DL, ScalarVT); + assert(InOp.getValueType().isInteger() && ScalarVT.isInteger() && + InOp.getValueType().bitsGT(ScalarVT)); + return DAG.getNode(ISD::TRUNCATE, DL, ScalarVT, InOp); } return InOp; } @@ -19655,7 +20155,7 @@ SDValue DAGCombiner::reduceBuildVecTruncToBitCast(SDNode *N) { if (!isa<ConstantSDNode>(ShiftAmtVal)) return SDValue(); - uint64_t ShiftAmt = In.getNode()->getConstantOperandVal(1); + uint64_t ShiftAmt = In.getConstantOperandVal(1); // The extracted value is not extracted at the right position if (ShiftAmt != i * ScalarTypeBitsize) @@ -20096,18 +20596,39 @@ SDValue DAGCombiner::reduceBuildVecToShuffle(SDNode *N) { int Left = 2 * In; int Right = 2 * In + 1; SmallVector<int, 8> Mask(NumElems, -1); - for (unsigned i = 0; i != NumElems; ++i) { - if (VectorMask[i] == Left) { - Mask[i] = i; - VectorMask[i] = In; - } else if (VectorMask[i] == Right) { - Mask[i] = i + NumElems; - VectorMask[i] = In; + SDValue L = Shuffles[Left]; + ArrayRef<int> LMask; + bool IsLeftShuffle = L.getOpcode() == ISD::VECTOR_SHUFFLE && + L.use_empty() && L.getOperand(1).isUndef() && + L.getOperand(0).getValueType() == L.getValueType(); + if (IsLeftShuffle) { + LMask = cast<ShuffleVectorSDNode>(L.getNode())->getMask(); + L = L.getOperand(0); + } + SDValue R = Shuffles[Right]; + ArrayRef<int> RMask; + bool IsRightShuffle = R.getOpcode() == ISD::VECTOR_SHUFFLE && + R.use_empty() && R.getOperand(1).isUndef() && + R.getOperand(0).getValueType() == R.getValueType(); + if (IsRightShuffle) { + RMask = cast<ShuffleVectorSDNode>(R.getNode())->getMask(); + R = R.getOperand(0); + } + for (unsigned I = 0; I != NumElems; ++I) { + if (VectorMask[I] == Left) { + Mask[I] = I; + if (IsLeftShuffle) + Mask[I] = LMask[I]; + VectorMask[I] = In; + } else if (VectorMask[I] == Right) { + Mask[I] = I + NumElems; + if (IsRightShuffle) + Mask[I] = RMask[I] + NumElems; + VectorMask[I] = In; } } - Shuffles[In] = - DAG.getVectorShuffle(VT, DL, Shuffles[Left], Shuffles[Right], Mask); + Shuffles[In] = DAG.getVectorShuffle(VT, DL, L, R, Mask); } } return Shuffles[0]; @@ -20695,7 +21216,7 @@ static SDValue narrowInsertExtractVectorBinOp(SDNode *Extract, const TargetLowering &TLI = DAG.getTargetLoweringInfo(); SDValue BinOp = Extract->getOperand(0); unsigned BinOpcode = BinOp.getOpcode(); - if (!TLI.isBinOp(BinOpcode) || BinOp.getNode()->getNumValues() != 1) + if (!TLI.isBinOp(BinOpcode) || BinOp->getNumValues() != 1) return SDValue(); EVT VecVT = BinOp.getValueType(); @@ -20744,7 +21265,7 @@ static SDValue narrowExtractedVectorBinOp(SDNode *Extract, SelectionDAG &DAG, const TargetLowering &TLI = DAG.getTargetLoweringInfo(); SDValue BinOp = peekThroughBitcasts(Extract->getOperand(0)); unsigned BOpcode = BinOp.getOpcode(); - if (!TLI.isBinOp(BOpcode) || BinOp.getNode()->getNumValues() != 1) + if (!TLI.isBinOp(BOpcode) || BinOp->getNumValues() != 1) return SDValue(); // Exclude the fake form of fneg (fsub -0.0, x) because that is likely to be @@ -20803,8 +21324,8 @@ static SDValue narrowExtractedVectorBinOp(SDNode *Extract, SelectionDAG &DAG, BinOp.getOperand(0), NewExtIndex); SDValue Y = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, NarrowBVT, BinOp.getOperand(1), NewExtIndex); - SDValue NarrowBinOp = DAG.getNode(BOpcode, DL, NarrowBVT, X, Y, - BinOp.getNode()->getFlags()); + SDValue NarrowBinOp = + DAG.getNode(BOpcode, DL, NarrowBVT, X, Y, BinOp->getFlags()); return DAG.getBitcast(VT, NarrowBinOp); } @@ -21085,6 +21606,12 @@ SDValue DAGCombiner::visitEXTRACT_SUBVECTOR(SDNode *N) { } } + // ty1 extract_vector(ty2 splat(V))) -> ty1 splat(V) + if (V.getOpcode() == ISD::SPLAT_VECTOR) + if (DAG.isConstantValueOfAnyType(V.getOperand(0)) || V.hasOneUse()) + if (!LegalOperations || TLI.isOperationLegal(ISD::SPLAT_VECTOR, NVT)) + return DAG.getSplatVector(NVT, SDLoc(N), V.getOperand(0)); + // Try to move vector bitcast after extract_subv by scaling extraction index: // extract_subv (bitcast X), Index --> bitcast (extract_subv X, Index') if (V.getOpcode() == ISD::BITCAST && @@ -21450,9 +21977,10 @@ static SDValue combineShuffleOfScalars(ShuffleVectorSDNode *SVN, SVT = (SVT.bitsLT(Op.getValueType()) ? Op.getValueType() : SVT); if (SVT != VT.getScalarType()) for (SDValue &Op : Ops) - Op = TLI.isZExtFree(Op.getValueType(), SVT) - ? DAG.getZExtOrTrunc(Op, SDLoc(SVN), SVT) - : DAG.getSExtOrTrunc(Op, SDLoc(SVN), SVT); + Op = Op.isUndef() ? DAG.getUNDEF(SVT) + : (TLI.isZExtFree(Op.getValueType(), SVT) + ? DAG.getZExtOrTrunc(Op, SDLoc(SVN), SVT) + : DAG.getSExtOrTrunc(Op, SDLoc(SVN), SVT)); return DAG.getBuildVector(VT, SDLoc(SVN), Ops); } @@ -21582,6 +22110,13 @@ static SDValue combineShuffleOfSplatVal(ShuffleVectorSDNode *Shuf, SelectionDAG &DAG) { if (!Shuf->getOperand(1).isUndef()) return SDValue(); + + // If the inner operand is a known splat with no undefs, just return that directly. + // TODO: Create DemandedElts mask from Shuf's mask. + // TODO: Allow undef elements and merge with the shuffle code below. + if (DAG.isSplatValue(Shuf->getOperand(0), /*AllowUndefs*/ false)) + return Shuf->getOperand(0); + auto *Splat = dyn_cast<ShuffleVectorSDNode>(Shuf->getOperand(0)); if (!Splat || !Splat->isSplat()) return SDValue(); @@ -21628,6 +22163,53 @@ static SDValue combineShuffleOfSplatVal(ShuffleVectorSDNode *Shuf, NewMask); } +// Combine shuffles of bitcasts into a shuffle of the bitcast type, providing +// the mask can be treated as a larger type. +static SDValue combineShuffleOfBitcast(ShuffleVectorSDNode *SVN, + SelectionDAG &DAG, + const TargetLowering &TLI, + bool LegalOperations) { + SDValue Op0 = SVN->getOperand(0); + SDValue Op1 = SVN->getOperand(1); + EVT VT = SVN->getValueType(0); + if (Op0.getOpcode() != ISD::BITCAST) + return SDValue(); + EVT InVT = Op0.getOperand(0).getValueType(); + if (!InVT.isVector() || + (!Op1.isUndef() && (Op1.getOpcode() != ISD::BITCAST || + Op1.getOperand(0).getValueType() != InVT))) + return SDValue(); + if (isAnyConstantBuildVector(Op0.getOperand(0)) && + (Op1.isUndef() || isAnyConstantBuildVector(Op1.getOperand(0)))) + return SDValue(); + + int VTLanes = VT.getVectorNumElements(); + int InLanes = InVT.getVectorNumElements(); + if (VTLanes <= InLanes || VTLanes % InLanes != 0 || + (LegalOperations && + !TLI.isOperationLegalOrCustom(ISD::VECTOR_SHUFFLE, InVT))) + return SDValue(); + int Factor = VTLanes / InLanes; + + // Check that each group of lanes in the mask are either undef or make a valid + // mask for the wider lane type. + ArrayRef<int> Mask = SVN->getMask(); + SmallVector<int> NewMask; + if (!widenShuffleMaskElts(Factor, Mask, NewMask)) + return SDValue(); + + if (!TLI.isShuffleMaskLegal(NewMask, InVT)) + return SDValue(); + + // Create the new shuffle with the new mask and bitcast it back to the + // original type. + SDLoc DL(SVN); + Op0 = Op0.getOperand(0); + Op1 = Op1.isUndef() ? DAG.getUNDEF(InVT) : Op1.getOperand(0); + SDValue NewShuf = DAG.getVectorShuffle(InVT, DL, Op0, Op1, NewMask); + return DAG.getBitcast(VT, NewShuf); +} + /// Combine shuffle of shuffle of the form: /// shuf (shuf X, undef, InnerMask), undef, OuterMask --> splat X static SDValue formSplatFromShuffles(ShuffleVectorSDNode *OuterShuf, @@ -21839,7 +22421,7 @@ SDValue DAGCombiner::visitVECTOR_SHUFFLE(SDNode *N) { if (SVN->isSplat() && SVN->getSplatIndex() < (int)NumElts) { int SplatIndex = SVN->getSplatIndex(); if (N0.hasOneUse() && TLI.isExtractVecEltCheap(VT, SplatIndex) && - TLI.isBinOp(N0.getOpcode()) && N0.getNode()->getNumValues() == 1) { + TLI.isBinOp(N0.getOpcode()) && N0->getNumValues() == 1) { // splat (vector_bo L, R), Index --> // splat (scalar_bo (extelt L, Index), (extelt R, Index)) SDValue L = N0.getOperand(0), R = N0.getOperand(1); @@ -21848,13 +22430,26 @@ SDValue DAGCombiner::visitVECTOR_SHUFFLE(SDNode *N) { SDValue Index = DAG.getVectorIdxConstant(SplatIndex, DL); SDValue ExtL = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, L, Index); SDValue ExtR = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, R, Index); - SDValue NewBO = DAG.getNode(N0.getOpcode(), DL, EltVT, ExtL, ExtR, - N0.getNode()->getFlags()); + SDValue NewBO = + DAG.getNode(N0.getOpcode(), DL, EltVT, ExtL, ExtR, N0->getFlags()); SDValue Insert = DAG.getNode(ISD::SCALAR_TO_VECTOR, DL, VT, NewBO); SmallVector<int, 16> ZeroMask(VT.getVectorNumElements(), 0); return DAG.getVectorShuffle(VT, DL, Insert, DAG.getUNDEF(VT), ZeroMask); } + // splat(scalar_to_vector(x), 0) -> build_vector(x,...,x) + // splat(insert_vector_elt(v, x, c), c) -> build_vector(x,...,x) + if ((!LegalOperations || TLI.isOperationLegal(ISD::BUILD_VECTOR, VT)) && + N0.hasOneUse()) { + if (N0.getOpcode() == ISD::SCALAR_TO_VECTOR && SplatIndex == 0) + return DAG.getSplatBuildVector(VT, SDLoc(N), N0.getOperand(0)); + + if (N0.getOpcode() == ISD::INSERT_VECTOR_ELT) + if (auto *Idx = dyn_cast<ConstantSDNode>(N0.getOperand(2))) + if (Idx->getAPIntValue() == SplatIndex) + return DAG.getSplatBuildVector(VT, SDLoc(N), N0.getOperand(1)); + } + // If this is a bit convert that changes the element type of the vector but // not the number of vector elements, look through it. Be careful not to // look though conversions that change things like v4f32 to v2f64. @@ -22078,6 +22673,11 @@ SDValue DAGCombiner::visitVECTOR_SHUFFLE(SDNode *N) { } } + // Match shuffles of bitcasts, so long as the mask can be treated as the + // larger type. + if (SDValue V = combineShuffleOfBitcast(SVN, DAG, TLI, LegalOperations)) + return V; + // Compute the combined shuffle mask for a shuffle with SV0 as the first // operand, and SV1 as the second operand. // i.e. Merge SVN(OtherSVN, N1) -> shuffle(SV0, SV1, Mask) iff Commute = false @@ -22409,6 +23009,11 @@ SDValue DAGCombiner::visitINSERT_SUBVECTOR(SDNode *N) { N1.getOperand(1) == N2 && N1.getOperand(0).getValueType() == VT) return N1.getOperand(0); + // Simplify scalar inserts into an undef vector: + // insert_subvector undef, (splat X), N2 -> splat X + if (N0.isUndef() && N1.getOpcode() == ISD::SPLAT_VECTOR) + return DAG.getNode(ISD::SPLAT_VECTOR, SDLoc(N), VT, N1.getOperand(0)); + // If we are inserting a bitcast value into an undef, with the same // number of elements, just use the bitcast input of the extract. // i.e. INSERT_SUBVECTOR UNDEF (BITCAST N1) N2 -> @@ -22556,6 +23161,16 @@ SDValue DAGCombiner::visitFP16_TO_FP(SDNode *N) { return SDValue(); } +SDValue DAGCombiner::visitFP_TO_BF16(SDNode *N) { + SDValue N0 = N->getOperand(0); + + // fold (fp_to_bf16 (bf16_to_fp op)) -> op + if (N0->getOpcode() == ISD::BF16_TO_FP) + return N0->getOperand(0); + + return SDValue(); +} + SDValue DAGCombiner::visitVECREDUCE(SDNode *N) { SDValue N0 = N->getOperand(0); EVT VT = N0.getValueType(); @@ -22583,6 +23198,19 @@ SDValue DAGCombiner::visitVECREDUCE(SDNode *N) { return DAG.getNode(NewOpcode, SDLoc(N), N->getValueType(0), N0); } + // vecreduce_or(insert_subvector(zero or undef, val)) -> vecreduce_or(val) + // vecreduce_and(insert_subvector(ones or undef, val)) -> vecreduce_and(val) + if (N0.getOpcode() == ISD::INSERT_SUBVECTOR && + TLI.isTypeLegal(N0.getOperand(1).getValueType())) { + SDValue Vec = N0.getOperand(0); + SDValue Subvec = N0.getOperand(1); + if ((Opcode == ISD::VECREDUCE_OR && + (N0.getOperand(0).isUndef() || isNullOrNullSplat(Vec))) || + (Opcode == ISD::VECREDUCE_AND && + (N0.getOperand(0).isUndef() || isAllOnesOrAllOnesSplat(Vec)))) + return DAG.getNode(Opcode, SDLoc(N), N->getValueType(0), Subvec); + } + return SDValue(); } @@ -22886,7 +23514,7 @@ SDValue DAGCombiner::SimplifySelect(const SDLoc &DL, SDValue N0, SDValue N1, // Check to see if we got a select_cc back (to turn into setcc/select). // Otherwise, just return whatever node we got back, like fabs. if (SCC.getOpcode() == ISD::SELECT_CC) { - const SDNodeFlags Flags = N0.getNode()->getFlags(); + const SDNodeFlags Flags = N0->getFlags(); SDValue SETCC = DAG.getNode(ISD::SETCC, SDLoc(N0), N0.getValueType(), SCC.getOperand(0), SCC.getOperand(1), @@ -23556,6 +24184,27 @@ SDValue DAGCombiner::BuildUDIV(SDNode *N) { return SDValue(); } +/// Given an ISD::SREM node expressing a remainder by constant power of 2, +/// return a DAG expression that will generate the same value. +SDValue DAGCombiner::BuildSREMPow2(SDNode *N) { + ConstantSDNode *C = isConstOrConstSplat(N->getOperand(1)); + if (!C) + return SDValue(); + + // Avoid division by zero. + if (C->isZero()) + return SDValue(); + + SmallVector<SDNode *, 8> Built; + if (SDValue S = TLI.BuildSREMPow2(N, C->getAPIntValue(), DAG, Built)) { + for (SDNode *N : Built) + AddToWorklist(N); + return S; + } + + return SDValue(); +} + /// Determines the LogBase2 value for a non-null input value using the /// transform: LogBase2(V) = (EltBits - 1) - ctlz(V). SDValue DAGCombiner::BuildLogBase2(SDValue V, const SDLoc &DL) { @@ -23865,9 +24514,8 @@ bool DAGCombiner::mayAlias(SDNode *Op0, SDNode *Op1) const { auto &Size0 = MUC0.NumBytes; auto &Size1 = MUC1.NumBytes; if (OrigAlignment0 == OrigAlignment1 && SrcValOffset0 != SrcValOffset1 && - Size0.hasValue() && Size1.hasValue() && *Size0 == *Size1 && - OrigAlignment0 > *Size0 && SrcValOffset0 % *Size0 == 0 && - SrcValOffset1 % *Size1 == 0) { + Size0 && Size1 && *Size0 == *Size1 && OrigAlignment0 > *Size0 && + SrcValOffset0 % *Size0 == 0 && SrcValOffset1 % *Size1 == 0) { int64_t OffAlign0 = SrcValOffset0 % OrigAlignment0.value(); int64_t OffAlign1 = SrcValOffset1 % OrigAlignment1.value(); @@ -23886,8 +24534,8 @@ bool DAGCombiner::mayAlias(SDNode *Op0, SDNode *Op1) const { UseAA = false; #endif - if (UseAA && AA && MUC0.MMO->getValue() && MUC1.MMO->getValue() && - Size0.hasValue() && Size1.hasValue()) { + if (UseAA && AA && MUC0.MMO->getValue() && MUC1.MMO->getValue() && Size0 && + Size1) { // Use alias analysis information. int64_t MinOffset = std::min(SrcValOffset0, SrcValOffset1); int64_t Overlap0 = *Size0 + SrcValOffset0 - MinOffset; @@ -23920,7 +24568,7 @@ void DAGCombiner::GatherAllAliases(SDNode *N, SDValue OriginalChain, unsigned Depth = 0; // Attempt to improve chain by a single step - std::function<bool(SDValue &)> ImproveChain = [&](SDValue &C) -> bool { + auto ImproveChain = [&](SDValue &C) -> bool { switch (C.getOpcode()) { case ISD::EntryToken: // No need to mark EntryToken. |