diff options
Diffstat (limited to 'lib/CodeGen/SelectionDAG/TargetLowering.cpp')
-rw-r--r-- | lib/CodeGen/SelectionDAG/TargetLowering.cpp | 1723 |
1 files changed, 1300 insertions, 423 deletions
diff --git a/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/lib/CodeGen/SelectionDAG/TargetLowering.cpp index a2f05c1e3cef..b260cd91d468 100644 --- a/lib/CodeGen/SelectionDAG/TargetLowering.cpp +++ b/lib/CodeGen/SelectionDAG/TargetLowering.cpp @@ -1,9 +1,8 @@ //===-- TargetLowering.cpp - Implement the TargetLowering class -----------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // @@ -100,19 +99,22 @@ bool TargetLowering::parametersInCSRMatch(const MachineRegisterInfo &MRI, /// Set CallLoweringInfo attribute flags based on a call instruction /// and called function attributes. -void TargetLoweringBase::ArgListEntry::setAttributes(ImmutableCallSite *CS, +void TargetLoweringBase::ArgListEntry::setAttributes(const CallBase *Call, unsigned ArgIdx) { - IsSExt = CS->paramHasAttr(ArgIdx, Attribute::SExt); - IsZExt = CS->paramHasAttr(ArgIdx, Attribute::ZExt); - IsInReg = CS->paramHasAttr(ArgIdx, Attribute::InReg); - IsSRet = CS->paramHasAttr(ArgIdx, Attribute::StructRet); - IsNest = CS->paramHasAttr(ArgIdx, Attribute::Nest); - IsByVal = CS->paramHasAttr(ArgIdx, Attribute::ByVal); - IsInAlloca = CS->paramHasAttr(ArgIdx, Attribute::InAlloca); - IsReturned = CS->paramHasAttr(ArgIdx, Attribute::Returned); - IsSwiftSelf = CS->paramHasAttr(ArgIdx, Attribute::SwiftSelf); - IsSwiftError = CS->paramHasAttr(ArgIdx, Attribute::SwiftError); - Alignment = CS->getParamAlignment(ArgIdx); + IsSExt = Call->paramHasAttr(ArgIdx, Attribute::SExt); + IsZExt = Call->paramHasAttr(ArgIdx, Attribute::ZExt); + IsInReg = Call->paramHasAttr(ArgIdx, Attribute::InReg); + IsSRet = Call->paramHasAttr(ArgIdx, Attribute::StructRet); + IsNest = Call->paramHasAttr(ArgIdx, Attribute::Nest); + IsByVal = Call->paramHasAttr(ArgIdx, Attribute::ByVal); + IsInAlloca = Call->paramHasAttr(ArgIdx, Attribute::InAlloca); + IsReturned = Call->paramHasAttr(ArgIdx, Attribute::Returned); + IsSwiftSelf = Call->paramHasAttr(ArgIdx, Attribute::SwiftSelf); + IsSwiftError = Call->paramHasAttr(ArgIdx, Attribute::SwiftError); + Alignment = Call->getParamAlignment(ArgIdx); + ByValType = nullptr; + if (Call->paramHasAttr(ArgIdx, Attribute::ByVal)) + ByValType = Call->getParamByValType(ArgIdx); } /// Generate a libcall taking the given operands as arguments and returning a @@ -121,7 +123,8 @@ std::pair<SDValue, SDValue> TargetLowering::makeLibCall(SelectionDAG &DAG, RTLIB::Libcall LC, EVT RetVT, ArrayRef<SDValue> Ops, bool isSigned, const SDLoc &dl, bool doesNotReturn, - bool isReturnValueUsed) const { + bool isReturnValueUsed, + bool isPostTypeLegalization) const { TargetLowering::ArgListTy Args; Args.reserve(Ops.size()); @@ -147,11 +150,114 @@ TargetLowering::makeLibCall(SelectionDAG &DAG, RTLIB::Libcall LC, EVT RetVT, .setLibCallee(getLibcallCallingConv(LC), RetTy, Callee, std::move(Args)) .setNoReturn(doesNotReturn) .setDiscardResult(!isReturnValueUsed) + .setIsPostTypeLegalization(isPostTypeLegalization) .setSExtResult(signExtend) .setZExtResult(!signExtend); return LowerCallTo(CLI); } +bool +TargetLowering::findOptimalMemOpLowering(std::vector<EVT> &MemOps, + unsigned Limit, uint64_t Size, + unsigned DstAlign, unsigned SrcAlign, + bool IsMemset, + bool ZeroMemset, + bool MemcpyStrSrc, + bool AllowOverlap, + unsigned DstAS, unsigned SrcAS, + const AttributeList &FuncAttributes) const { + // If 'SrcAlign' is zero, that means the memory operation does not need to + // load the value, i.e. memset or memcpy from constant string. Otherwise, + // it's the inferred alignment of the source. 'DstAlign', on the other hand, + // is the specified alignment of the memory operation. If it is zero, that + // means it's possible to change the alignment of the destination. + // 'MemcpyStrSrc' indicates whether the memcpy source is constant so it does + // not need to be loaded. + if (!(SrcAlign == 0 || SrcAlign >= DstAlign)) + return false; + + EVT VT = getOptimalMemOpType(Size, DstAlign, SrcAlign, + IsMemset, ZeroMemset, MemcpyStrSrc, + FuncAttributes); + + if (VT == MVT::Other) { + // Use the largest integer type whose alignment constraints are satisfied. + // We only need to check DstAlign here as SrcAlign is always greater or + // equal to DstAlign (or zero). + VT = MVT::i64; + while (DstAlign && DstAlign < VT.getSizeInBits() / 8 && + !allowsMisalignedMemoryAccesses(VT, DstAS, DstAlign)) + VT = (MVT::SimpleValueType)(VT.getSimpleVT().SimpleTy - 1); + assert(VT.isInteger()); + + // Find the largest legal integer type. + MVT LVT = MVT::i64; + while (!isTypeLegal(LVT)) + LVT = (MVT::SimpleValueType)(LVT.SimpleTy - 1); + assert(LVT.isInteger()); + + // If the type we've chosen is larger than the largest legal integer type + // then use that instead. + if (VT.bitsGT(LVT)) + VT = LVT; + } + + unsigned NumMemOps = 0; + while (Size != 0) { + unsigned VTSize = VT.getSizeInBits() / 8; + while (VTSize > Size) { + // For now, only use non-vector load / store's for the left-over pieces. + EVT NewVT = VT; + unsigned NewVTSize; + + bool Found = false; + if (VT.isVector() || VT.isFloatingPoint()) { + NewVT = (VT.getSizeInBits() > 64) ? MVT::i64 : MVT::i32; + if (isOperationLegalOrCustom(ISD::STORE, NewVT) && + isSafeMemOpType(NewVT.getSimpleVT())) + Found = true; + else if (NewVT == MVT::i64 && + isOperationLegalOrCustom(ISD::STORE, MVT::f64) && + isSafeMemOpType(MVT::f64)) { + // i64 is usually not legal on 32-bit targets, but f64 may be. + NewVT = MVT::f64; + Found = true; + } + } + + if (!Found) { + do { + NewVT = (MVT::SimpleValueType)(NewVT.getSimpleVT().SimpleTy - 1); + if (NewVT == MVT::i8) + break; + } while (!isSafeMemOpType(NewVT.getSimpleVT())); + } + NewVTSize = NewVT.getSizeInBits() / 8; + + // If the new VT cannot cover all of the remaining bits, then consider + // issuing a (or a pair of) unaligned and overlapping load / store. + bool Fast; + if (NumMemOps && AllowOverlap && NewVTSize < Size && + allowsMisalignedMemoryAccesses(VT, DstAS, DstAlign, + MachineMemOperand::MONone, &Fast) && + Fast) + VTSize = Size; + else { + VT = NewVT; + VTSize = NewVTSize; + } + } + + if (++NumMemOps > Limit) + return false; + + MemOps.push_back(VT); + Size -= VTSize; + } + + return true; +} + /// Soften the operands of a comparison. This code is shared among BR_CC, /// SELECT_CC, and SETCC handlers. void TargetLowering::softenSetCCOperands(SelectionDAG &DAG, EVT VT, @@ -346,7 +452,6 @@ TargetLowering::isOffsetFoldingLegal(const GlobalAddressSDNode *GA) const { /// return true. bool TargetLowering::ShrinkDemandedConstant(SDValue Op, const APInt &Demanded, TargetLoweringOpt &TLO) const { - SelectionDAG &DAG = TLO.DAG; SDLoc DL(Op); unsigned Opcode = Op.getOpcode(); @@ -372,8 +477,8 @@ bool TargetLowering::ShrinkDemandedConstant(SDValue Op, const APInt &Demanded, if (!C.isSubsetOf(Demanded)) { EVT VT = Op.getValueType(); - SDValue NewC = DAG.getConstant(Demanded & C, DL, VT); - SDValue NewOp = DAG.getNode(Opcode, DL, VT, Op.getOperand(0), NewC); + SDValue NewC = TLO.DAG.getConstant(Demanded & C, DL, VT); + SDValue NewOp = TLO.DAG.getNode(Opcode, DL, VT, Op.getOperand(0), NewC); return TLO.CombineTo(Op, NewOp); } @@ -487,6 +592,10 @@ bool TargetLowering::SimplifyDemandedBits( // Don't know anything. Known = KnownBits(BitWidth); + // Undef operand. + if (Op.isUndef()) + return false; + if (Op.getOpcode() == ISD::Constant) { // We know all of the bits for a constant! Known.One = cast<ConstantSDNode>(Op)->getAPIntValue(); @@ -509,40 +618,116 @@ bool TargetLowering::SimplifyDemandedBits( DemandedElts = APInt::getAllOnesValue(NumElts); } else if (OriginalDemandedBits == 0 || OriginalDemandedElts == 0) { // Not demanding any bits/elts from Op. - if (!Op.isUndef()) - return TLO.CombineTo(Op, TLO.DAG.getUNDEF(VT)); - return false; + return TLO.CombineTo(Op, TLO.DAG.getUNDEF(VT)); } else if (Depth == 6) { // Limit search depth. return false; } KnownBits Known2, KnownOut; switch (Op.getOpcode()) { + case ISD::SCALAR_TO_VECTOR: { + if (!DemandedElts[0]) + return TLO.CombineTo(Op, TLO.DAG.getUNDEF(VT)); + + KnownBits SrcKnown; + SDValue Src = Op.getOperand(0); + unsigned SrcBitWidth = Src.getScalarValueSizeInBits(); + APInt SrcDemandedBits = DemandedBits.zextOrSelf(SrcBitWidth); + if (SimplifyDemandedBits(Src, SrcDemandedBits, SrcKnown, TLO, Depth + 1)) + return true; + Known = SrcKnown.zextOrTrunc(BitWidth, false); + break; + } case ISD::BUILD_VECTOR: - // Collect the known bits that are shared by every constant vector element. - Known.Zero.setAllBits(); Known.One.setAllBits(); - for (SDValue SrcOp : Op->ops()) { - if (!isa<ConstantSDNode>(SrcOp)) { - // We can only handle all constant values - bail out with no known bits. - Known = KnownBits(BitWidth); - return false; - } - Known2.One = cast<ConstantSDNode>(SrcOp)->getAPIntValue(); - Known2.Zero = ~Known2.One; - - // BUILD_VECTOR can implicitly truncate sources, we must handle this. - if (Known2.One.getBitWidth() != BitWidth) { - assert(Known2.getBitWidth() > BitWidth && - "Expected BUILD_VECTOR implicit truncation"); - Known2 = Known2.trunc(BitWidth); + // Collect the known bits that are shared by every demanded element. + // TODO: Call SimplifyDemandedBits for non-constant demanded elements. + Known = TLO.DAG.computeKnownBits(Op, DemandedElts, Depth); + return false; // Don't fall through, will infinitely loop. + case ISD::LOAD: { + LoadSDNode *LD = cast<LoadSDNode>(Op); + if (getTargetConstantFromLoad(LD)) { + Known = TLO.DAG.computeKnownBits(Op, DemandedElts, Depth); + return false; // Don't fall through, will infinitely loop. + } + break; + } + case ISD::INSERT_VECTOR_ELT: { + SDValue Vec = Op.getOperand(0); + SDValue Scl = Op.getOperand(1); + auto *CIdx = dyn_cast<ConstantSDNode>(Op.getOperand(2)); + EVT VecVT = Vec.getValueType(); + + // If index isn't constant, assume we need all vector elements AND the + // inserted element. + APInt DemandedVecElts(DemandedElts); + if (CIdx && CIdx->getAPIntValue().ult(VecVT.getVectorNumElements())) { + unsigned Idx = CIdx->getZExtValue(); + DemandedVecElts.clearBit(Idx); + + // Inserted element is not required. + if (!DemandedElts[Idx]) + return TLO.CombineTo(Op, Vec); + } + + KnownBits KnownScl; + unsigned NumSclBits = Scl.getScalarValueSizeInBits(); + APInt DemandedSclBits = DemandedBits.zextOrTrunc(NumSclBits); + if (SimplifyDemandedBits(Scl, DemandedSclBits, KnownScl, TLO, Depth + 1)) + return true; + + Known = KnownScl.zextOrTrunc(BitWidth, false); + + KnownBits KnownVec; + if (SimplifyDemandedBits(Vec, DemandedBits, DemandedVecElts, KnownVec, TLO, + Depth + 1)) + return true; + + if (!!DemandedVecElts) { + Known.One &= KnownVec.One; + Known.Zero &= KnownVec.Zero; + } + + return false; + } + case ISD::INSERT_SUBVECTOR: { + SDValue Base = Op.getOperand(0); + SDValue Sub = Op.getOperand(1); + EVT SubVT = Sub.getValueType(); + unsigned NumSubElts = SubVT.getVectorNumElements(); + + // If index isn't constant, assume we need the original demanded base + // elements and ALL the inserted subvector elements. + APInt BaseElts = DemandedElts; + APInt SubElts = APInt::getAllOnesValue(NumSubElts); + if (isa<ConstantSDNode>(Op.getOperand(2))) { + const APInt &Idx = Op.getConstantOperandAPInt(2); + if (Idx.ule(NumElts - NumSubElts)) { + unsigned SubIdx = Idx.getZExtValue(); + SubElts = DemandedElts.extractBits(NumSubElts, SubIdx); + BaseElts.insertBits(APInt::getNullValue(NumSubElts), SubIdx); } + } - // Known bits are the values that are shared by every element. - // TODO: support per-element known bits. - Known.One &= Known2.One; - Known.Zero &= Known2.Zero; + KnownBits KnownSub, KnownBase; + if (SimplifyDemandedBits(Sub, DemandedBits, SubElts, KnownSub, TLO, + Depth + 1)) + return true; + if (SimplifyDemandedBits(Base, DemandedBits, BaseElts, KnownBase, TLO, + Depth + 1)) + return true; + + Known.Zero.setAllBits(); + Known.One.setAllBits(); + if (!!SubElts) { + Known.One &= KnownSub.One; + Known.Zero &= KnownSub.Zero; } - return false; // Don't fall through, will infinitely loop. + if (!!BaseElts) { + Known.One &= KnownBase.One; + Known.Zero &= KnownBase.Zero; + } + break; + } case ISD::CONCAT_VECTORS: { Known.Zero.setAllBits(); Known.One.setAllBits(); @@ -640,11 +825,12 @@ bool TargetLowering::SimplifyDemandedBits( } } - if (SimplifyDemandedBits(Op1, DemandedBits, DemandedElts, Known, TLO, Depth + 1)) + if (SimplifyDemandedBits(Op1, DemandedBits, DemandedElts, Known, TLO, + Depth + 1)) return true; assert(!Known.hasConflict() && "Bits known to be one AND zero?"); - if (SimplifyDemandedBits(Op0, ~Known.Zero & DemandedBits, DemandedElts, Known2, TLO, - Depth + 1)) + if (SimplifyDemandedBits(Op0, ~Known.Zero & DemandedBits, DemandedElts, + Known2, TLO, Depth + 1)) return true; assert(!Known2.hasConflict() && "Bits known to be one AND zero?"); @@ -674,11 +860,12 @@ bool TargetLowering::SimplifyDemandedBits( SDValue Op0 = Op.getOperand(0); SDValue Op1 = Op.getOperand(1); - if (SimplifyDemandedBits(Op1, DemandedBits, DemandedElts, Known, TLO, Depth + 1)) + if (SimplifyDemandedBits(Op1, DemandedBits, DemandedElts, Known, TLO, + Depth + 1)) return true; assert(!Known.hasConflict() && "Bits known to be one AND zero?"); - if (SimplifyDemandedBits(Op0, ~Known.One & DemandedBits, DemandedElts, Known2, TLO, - Depth + 1)) + if (SimplifyDemandedBits(Op0, ~Known.One & DemandedBits, DemandedElts, + Known2, TLO, Depth + 1)) return true; assert(!Known2.hasConflict() && "Bits known to be one AND zero?"); @@ -705,10 +892,12 @@ bool TargetLowering::SimplifyDemandedBits( SDValue Op0 = Op.getOperand(0); SDValue Op1 = Op.getOperand(1); - if (SimplifyDemandedBits(Op1, DemandedBits, DemandedElts, Known, TLO, Depth + 1)) + if (SimplifyDemandedBits(Op1, DemandedBits, DemandedElts, Known, TLO, + Depth + 1)) return true; assert(!Known.hasConflict() && "Bits known to be one AND zero?"); - if (SimplifyDemandedBits(Op0, DemandedBits, DemandedElts, Known2, TLO, Depth + 1)) + if (SimplifyDemandedBits(Op0, DemandedBits, DemandedElts, Known2, TLO, + Depth + 1)) return true; assert(!Known2.hasConflict() && "Bits known to be one AND zero?"); @@ -831,20 +1020,23 @@ bool TargetLowering::SimplifyDemandedBits( SDValue Op0 = Op.getOperand(0); SDValue Op1 = Op.getOperand(1); - if (ConstantSDNode *SA = isConstOrConstSplat(Op1)) { + if (ConstantSDNode *SA = isConstOrConstSplat(Op1, DemandedElts)) { // If the shift count is an invalid immediate, don't do anything. if (SA->getAPIntValue().uge(BitWidth)) break; unsigned ShAmt = SA->getZExtValue(); + if (ShAmt == 0) + return TLO.CombineTo(Op, Op0); // If this is ((X >>u C1) << ShAmt), see if we can simplify this into a // single shift. We can do this if the bottom bits (which are shifted // out) are never demanded. + // TODO - support non-uniform vector amounts. if (Op0.getOpcode() == ISD::SRL) { - if (ShAmt && - (DemandedBits & APInt::getLowBitsSet(BitWidth, ShAmt)) == 0) { - if (ConstantSDNode *SA2 = isConstOrConstSplat(Op0.getOperand(1))) { + if ((DemandedBits & APInt::getLowBitsSet(BitWidth, ShAmt)) == 0) { + if (ConstantSDNode *SA2 = + isConstOrConstSplat(Op0.getOperand(1), DemandedElts)) { if (SA2->getAPIntValue().ult(BitWidth)) { unsigned C1 = SA2->getZExtValue(); unsigned Opc = ISD::SHL; @@ -862,8 +1054,14 @@ bool TargetLowering::SimplifyDemandedBits( } } - if (SimplifyDemandedBits(Op0, DemandedBits.lshr(ShAmt), DemandedElts, Known, TLO, - Depth + 1)) + if (SimplifyDemandedBits(Op0, DemandedBits.lshr(ShAmt), DemandedElts, + Known, TLO, Depth + 1)) + return true; + + // Try shrinking the operation as long as the shift amount will still be + // in range. + if ((ShAmt < DemandedBits.getActiveBits()) && + ShrinkDemandedOp(Op, BitWidth, DemandedBits, TLO)) return true; // Convert (shl (anyext x, c)) to (anyext (shl x, c)) if the high bits @@ -919,12 +1117,16 @@ bool TargetLowering::SimplifyDemandedBits( SDValue Op0 = Op.getOperand(0); SDValue Op1 = Op.getOperand(1); - if (ConstantSDNode *SA = isConstOrConstSplat(Op1)) { + if (ConstantSDNode *SA = isConstOrConstSplat(Op1, DemandedElts)) { // If the shift count is an invalid immediate, don't do anything. if (SA->getAPIntValue().uge(BitWidth)) break; unsigned ShAmt = SA->getZExtValue(); + if (ShAmt == 0) + return TLO.CombineTo(Op, Op0); + + EVT ShiftVT = Op1.getValueType(); APInt InDemandedMask = (DemandedBits << ShAmt); // If the shift is exact, then it does demand the low bits (and knows that @@ -935,10 +1137,11 @@ bool TargetLowering::SimplifyDemandedBits( // If this is ((X << C1) >>u ShAmt), see if we can simplify this into a // single shift. We can do this if the top bits (which are shifted out) // are never demanded. + // TODO - support non-uniform vector amounts. if (Op0.getOpcode() == ISD::SHL) { - if (ConstantSDNode *SA2 = isConstOrConstSplat(Op0.getOperand(1))) { - if (ShAmt && - (DemandedBits & APInt::getHighBitsSet(BitWidth, ShAmt)) == 0) { + if (ConstantSDNode *SA2 = + isConstOrConstSplat(Op0.getOperand(1), DemandedElts)) { + if ((DemandedBits & APInt::getHighBitsSet(BitWidth, ShAmt)) == 0) { if (SA2->getAPIntValue().ult(BitWidth)) { unsigned C1 = SA2->getZExtValue(); unsigned Opc = ISD::SRL; @@ -948,7 +1151,7 @@ bool TargetLowering::SimplifyDemandedBits( Opc = ISD::SHL; } - SDValue NewSA = TLO.DAG.getConstant(Diff, dl, Op1.getValueType()); + SDValue NewSA = TLO.DAG.getConstant(Diff, dl, ShiftVT); return TLO.CombineTo( Op, TLO.DAG.getNode(Opc, dl, VT, Op0.getOperand(0), NewSA)); } @@ -957,7 +1160,8 @@ bool TargetLowering::SimplifyDemandedBits( } // Compute the new bits that are at the top now. - if (SimplifyDemandedBits(Op0, InDemandedMask, DemandedElts, Known, TLO, Depth + 1)) + if (SimplifyDemandedBits(Op0, InDemandedMask, DemandedElts, Known, TLO, + Depth + 1)) return true; assert(!Known.hasConflict() && "Bits known to be one AND zero?"); Known.Zero.lshrInPlace(ShAmt); @@ -978,12 +1182,15 @@ bool TargetLowering::SimplifyDemandedBits( if (DemandedBits.isOneValue()) return TLO.CombineTo(Op, TLO.DAG.getNode(ISD::SRL, dl, VT, Op0, Op1)); - if (ConstantSDNode *SA = isConstOrConstSplat(Op1)) { + if (ConstantSDNode *SA = isConstOrConstSplat(Op1, DemandedElts)) { // If the shift count is an invalid immediate, don't do anything. if (SA->getAPIntValue().uge(BitWidth)) break; unsigned ShAmt = SA->getZExtValue(); + if (ShAmt == 0) + return TLO.CombineTo(Op, Op0); + APInt InDemandedMask = (DemandedBits << ShAmt); // If the shift is exact, then it does demand the low bits (and knows that @@ -996,7 +1203,8 @@ bool TargetLowering::SimplifyDemandedBits( if (DemandedBits.countLeadingZeros() < ShAmt) InDemandedMask.setSignBit(); - if (SimplifyDemandedBits(Op0, InDemandedMask, DemandedElts, Known, TLO, Depth + 1)) + if (SimplifyDemandedBits(Op0, InDemandedMask, DemandedElts, Known, TLO, + Depth + 1)) return true; assert(!Known.hasConflict() && "Bits known to be one AND zero?"); Known.Zero.lshrInPlace(ShAmt); @@ -1026,6 +1234,55 @@ bool TargetLowering::SimplifyDemandedBits( } break; } + case ISD::FSHL: + case ISD::FSHR: { + SDValue Op0 = Op.getOperand(0); + SDValue Op1 = Op.getOperand(1); + SDValue Op2 = Op.getOperand(2); + bool IsFSHL = (Op.getOpcode() == ISD::FSHL); + + if (ConstantSDNode *SA = isConstOrConstSplat(Op2, DemandedElts)) { + unsigned Amt = SA->getAPIntValue().urem(BitWidth); + + // For fshl, 0-shift returns the 1st arg. + // For fshr, 0-shift returns the 2nd arg. + if (Amt == 0) { + if (SimplifyDemandedBits(IsFSHL ? Op0 : Op1, DemandedBits, DemandedElts, + Known, TLO, Depth + 1)) + return true; + break; + } + + // fshl: (Op0 << Amt) | (Op1 >> (BW - Amt)) + // fshr: (Op0 << (BW - Amt)) | (Op1 >> Amt) + APInt Demanded0 = DemandedBits.lshr(IsFSHL ? Amt : (BitWidth - Amt)); + APInt Demanded1 = DemandedBits << (IsFSHL ? (BitWidth - Amt) : Amt); + if (SimplifyDemandedBits(Op0, Demanded0, DemandedElts, Known2, TLO, + Depth + 1)) + return true; + if (SimplifyDemandedBits(Op1, Demanded1, DemandedElts, Known, TLO, + Depth + 1)) + return true; + + Known2.One <<= (IsFSHL ? Amt : (BitWidth - Amt)); + Known2.Zero <<= (IsFSHL ? Amt : (BitWidth - Amt)); + Known.One.lshrInPlace(IsFSHL ? (BitWidth - Amt) : Amt); + Known.Zero.lshrInPlace(IsFSHL ? (BitWidth - Amt) : Amt); + Known.One |= Known2.One; + Known.Zero |= Known2.Zero; + } + break; + } + case ISD::BITREVERSE: { + SDValue Src = Op.getOperand(0); + APInt DemandedSrcBits = DemandedBits.reverseBits(); + if (SimplifyDemandedBits(Src, DemandedSrcBits, DemandedElts, Known2, TLO, + Depth + 1)) + return true; + Known.One = Known2.One.reverseBits(); + Known.Zero = Known2.Zero.reverseBits(); + break; + } case ISD::SIGN_EXTEND_INREG: { SDValue Op0 = Op.getOperand(0); EVT ExVT = cast<VTSDNode>(Op.getOperand(1))->getVT(); @@ -1033,8 +1290,8 @@ bool TargetLowering::SimplifyDemandedBits( // If we only care about the highest bit, don't bother shifting right. if (DemandedBits.isSignMask()) { - bool AlreadySignExtended = - TLO.DAG.ComputeNumSignBits(Op0) >= BitWidth - ExVTBits + 1; + unsigned NumSignBits = TLO.DAG.ComputeNumSignBits(Op0); + bool AlreadySignExtended = NumSignBits >= BitWidth - ExVTBits + 1; // However if the input is already sign extended we expect the sign // extension to be dropped altogether later and do not simplify. if (!AlreadySignExtended) { @@ -1099,79 +1356,116 @@ bool TargetLowering::SimplifyDemandedBits( return true; Known.Zero = KnownLo.Zero.zext(BitWidth) | - KnownHi.Zero.zext(BitWidth).shl(HalfBitWidth); + KnownHi.Zero.zext(BitWidth).shl(HalfBitWidth); Known.One = KnownLo.One.zext(BitWidth) | - KnownHi.One.zext(BitWidth).shl(HalfBitWidth); + KnownHi.One.zext(BitWidth).shl(HalfBitWidth); break; } - case ISD::ZERO_EXTEND: { + case ISD::ZERO_EXTEND: + case ISD::ZERO_EXTEND_VECTOR_INREG: { SDValue Src = Op.getOperand(0); - unsigned InBits = Src.getScalarValueSizeInBits(); + EVT SrcVT = Src.getValueType(); + unsigned InBits = SrcVT.getScalarSizeInBits(); + unsigned InElts = SrcVT.isVector() ? SrcVT.getVectorNumElements() : 1; + bool IsVecInReg = Op.getOpcode() == ISD::ZERO_EXTEND_VECTOR_INREG; // If none of the top bits are demanded, convert this into an any_extend. - if (DemandedBits.getActiveBits() <= InBits) - return TLO.CombineTo(Op, TLO.DAG.getNode(ISD::ANY_EXTEND, dl, VT, Src)); + if (DemandedBits.getActiveBits() <= InBits) { + // If we only need the non-extended bits of the bottom element + // then we can just bitcast to the result. + if (IsVecInReg && DemandedElts == 1 && + VT.getSizeInBits() == SrcVT.getSizeInBits() && + TLO.DAG.getDataLayout().isLittleEndian()) + return TLO.CombineTo(Op, TLO.DAG.getBitcast(VT, Src)); + + unsigned Opc = + IsVecInReg ? ISD::ANY_EXTEND_VECTOR_INREG : ISD::ANY_EXTEND; + if (!TLO.LegalOperations() || isOperationLegal(Opc, VT)) + return TLO.CombineTo(Op, TLO.DAG.getNode(Opc, dl, VT, Src)); + } APInt InDemandedBits = DemandedBits.trunc(InBits); - if (SimplifyDemandedBits(Src, InDemandedBits, Known, TLO, Depth+1)) + APInt InDemandedElts = DemandedElts.zextOrSelf(InElts); + if (SimplifyDemandedBits(Src, InDemandedBits, InDemandedElts, Known, TLO, + Depth + 1)) return true; assert(!Known.hasConflict() && "Bits known to be one AND zero?"); - Known = Known.zext(BitWidth); - Known.Zero.setBitsFrom(InBits); + assert(Known.getBitWidth() == InBits && "Src width has changed?"); + Known = Known.zext(BitWidth, true /* ExtendedBitsAreKnownZero */); break; } - case ISD::SIGN_EXTEND: { + case ISD::SIGN_EXTEND: + case ISD::SIGN_EXTEND_VECTOR_INREG: { SDValue Src = Op.getOperand(0); - unsigned InBits = Src.getScalarValueSizeInBits(); + EVT SrcVT = Src.getValueType(); + unsigned InBits = SrcVT.getScalarSizeInBits(); + unsigned InElts = SrcVT.isVector() ? SrcVT.getVectorNumElements() : 1; + bool IsVecInReg = Op.getOpcode() == ISD::SIGN_EXTEND_VECTOR_INREG; // If none of the top bits are demanded, convert this into an any_extend. - if (DemandedBits.getActiveBits() <= InBits) - return TLO.CombineTo(Op, TLO.DAG.getNode(ISD::ANY_EXTEND, dl, VT, Src)); + if (DemandedBits.getActiveBits() <= InBits) { + // If we only need the non-extended bits of the bottom element + // then we can just bitcast to the result. + if (IsVecInReg && DemandedElts == 1 && + VT.getSizeInBits() == SrcVT.getSizeInBits() && + TLO.DAG.getDataLayout().isLittleEndian()) + return TLO.CombineTo(Op, TLO.DAG.getBitcast(VT, Src)); + + unsigned Opc = + IsVecInReg ? ISD::ANY_EXTEND_VECTOR_INREG : ISD::ANY_EXTEND; + if (!TLO.LegalOperations() || isOperationLegal(Opc, VT)) + return TLO.CombineTo(Op, TLO.DAG.getNode(Opc, dl, VT, Src)); + } + + APInt InDemandedBits = DemandedBits.trunc(InBits); + APInt InDemandedElts = DemandedElts.zextOrSelf(InElts); // Since some of the sign extended bits are demanded, we know that the sign // bit is demanded. - APInt InDemandedBits = DemandedBits.trunc(InBits); InDemandedBits.setBit(InBits - 1); - if (SimplifyDemandedBits(Src, InDemandedBits, Known, TLO, Depth + 1)) + if (SimplifyDemandedBits(Src, InDemandedBits, InDemandedElts, Known, TLO, + Depth + 1)) return true; assert(!Known.hasConflict() && "Bits known to be one AND zero?"); + assert(Known.getBitWidth() == InBits && "Src width has changed?"); + // If the sign bit is known one, the top bits match. Known = Known.sext(BitWidth); // If the sign bit is known zero, convert this to a zero extend. - if (Known.isNonNegative()) - return TLO.CombineTo(Op, TLO.DAG.getNode(ISD::ZERO_EXTEND, dl, VT, Src)); + if (Known.isNonNegative()) { + unsigned Opc = + IsVecInReg ? ISD::ZERO_EXTEND_VECTOR_INREG : ISD::ZERO_EXTEND; + if (!TLO.LegalOperations() || isOperationLegal(Opc, VT)) + return TLO.CombineTo(Op, TLO.DAG.getNode(Opc, dl, VT, Src)); + } break; } - case ISD::SIGN_EXTEND_VECTOR_INREG: { - // TODO - merge this with SIGN_EXTEND above? + case ISD::ANY_EXTEND: + case ISD::ANY_EXTEND_VECTOR_INREG: { SDValue Src = Op.getOperand(0); - unsigned InBits = Src.getScalarValueSizeInBits(); - - APInt InDemandedBits = DemandedBits.trunc(InBits); + EVT SrcVT = Src.getValueType(); + unsigned InBits = SrcVT.getScalarSizeInBits(); + unsigned InElts = SrcVT.isVector() ? SrcVT.getVectorNumElements() : 1; + bool IsVecInReg = Op.getOpcode() == ISD::ANY_EXTEND_VECTOR_INREG; - // If some of the sign extended bits are demanded, we know that the sign - // bit is demanded. - if (InBits < DemandedBits.getActiveBits()) - InDemandedBits.setBit(InBits - 1); + // If we only need the bottom element then we can just bitcast. + // TODO: Handle ANY_EXTEND? + if (IsVecInReg && DemandedElts == 1 && + VT.getSizeInBits() == SrcVT.getSizeInBits() && + TLO.DAG.getDataLayout().isLittleEndian()) + return TLO.CombineTo(Op, TLO.DAG.getBitcast(VT, Src)); - if (SimplifyDemandedBits(Src, InDemandedBits, Known, TLO, Depth + 1)) - return true; - assert(!Known.hasConflict() && "Bits known to be one AND zero?"); - // If the sign bit is known one, the top bits match. - Known = Known.sext(BitWidth); - break; - } - case ISD::ANY_EXTEND: { - SDValue Src = Op.getOperand(0); - unsigned InBits = Src.getScalarValueSizeInBits(); APInt InDemandedBits = DemandedBits.trunc(InBits); - if (SimplifyDemandedBits(Src, InDemandedBits, Known, TLO, Depth+1)) + APInt InDemandedElts = DemandedElts.zextOrSelf(InElts); + if (SimplifyDemandedBits(Src, InDemandedBits, InDemandedElts, Known, TLO, + Depth + 1)) return true; assert(!Known.hasConflict() && "Bits known to be one AND zero?"); - Known = Known.zext(BitWidth); + assert(Known.getBitWidth() == InBits && "Src width has changed?"); + Known = Known.zext(BitWidth, false /* => any extend */); break; } case ISD::TRUNCATE: { @@ -1198,29 +1492,29 @@ bool TargetLowering::SimplifyDemandedBits( // Do not turn (vt1 truncate (vt2 srl)) into (vt1 srl) if vt1 is // undesirable. break; - ConstantSDNode *ShAmt = dyn_cast<ConstantSDNode>(Src.getOperand(1)); - if (!ShAmt) + + auto *ShAmt = dyn_cast<ConstantSDNode>(Src.getOperand(1)); + if (!ShAmt || ShAmt->getAPIntValue().uge(BitWidth)) break; + SDValue Shift = Src.getOperand(1); - if (TLO.LegalTypes()) { - uint64_t ShVal = ShAmt->getZExtValue(); + uint64_t ShVal = ShAmt->getZExtValue(); + + if (TLO.LegalTypes()) Shift = TLO.DAG.getConstant(ShVal, dl, getShiftAmountTy(VT, DL)); - } - if (ShAmt->getZExtValue() < BitWidth) { - APInt HighBits = APInt::getHighBitsSet(OperandBitWidth, - OperandBitWidth - BitWidth); - HighBits.lshrInPlace(ShAmt->getZExtValue()); - HighBits = HighBits.trunc(BitWidth); - - if (!(HighBits & DemandedBits)) { - // None of the shifted in bits are needed. Add a truncate of the - // shift input, then shift it. - SDValue NewTrunc = - TLO.DAG.getNode(ISD::TRUNCATE, dl, VT, Src.getOperand(0)); - return TLO.CombineTo( - Op, TLO.DAG.getNode(ISD::SRL, dl, VT, NewTrunc, Shift)); - } + APInt HighBits = + APInt::getHighBitsSet(OperandBitWidth, OperandBitWidth - BitWidth); + HighBits.lshrInPlace(ShVal); + HighBits = HighBits.trunc(BitWidth); + + if (!(HighBits & DemandedBits)) { + // None of the shifted in bits are needed. Add a truncate of the + // shift input, then shift it. + SDValue NewTrunc = + TLO.DAG.getNode(ISD::TRUNCATE, dl, VT, Src.getOperand(0)); + return TLO.CombineTo( + Op, TLO.DAG.getNode(ISD::SRL, dl, VT, NewTrunc, Shift)); } break; } @@ -1234,8 +1528,8 @@ bool TargetLowering::SimplifyDemandedBits( // demanded by its users. EVT ZVT = cast<VTSDNode>(Op.getOperand(1))->getVT(); APInt InMask = APInt::getLowBitsSet(BitWidth, ZVT.getSizeInBits()); - if (SimplifyDemandedBits(Op.getOperand(0), ~InMask | DemandedBits, - Known, TLO, Depth+1)) + if (SimplifyDemandedBits(Op.getOperand(0), ~InMask | DemandedBits, Known, + TLO, Depth + 1)) return true; assert(!Known.hasConflict() && "Bits known to be one AND zero?"); @@ -1266,7 +1560,7 @@ bool TargetLowering::SimplifyDemandedBits( Known = Known2; if (BitWidth > EltBitWidth) - Known = Known.zext(BitWidth); + Known = Known.zext(BitWidth, false /* => any extend */); break; } case ISD::BITCAST: { @@ -1297,40 +1591,68 @@ bool TargetLowering::SimplifyDemandedBits( TLO.DAG.getNode(ISD::SHL, dl, VT, Sign, ShAmt)); } } - // If bitcast from a vector, see if we can use SimplifyDemandedVectorElts by - // demanding the element if any bits from it are demanded. + + // Bitcast from a vector using SimplifyDemanded Bits/VectorElts. + // Demand the elt/bit if any of the original elts/bits are demanded. // TODO - bigendian once we have test coverage. // TODO - bool vectors once SimplifyDemandedVectorElts has SETCC support. if (SrcVT.isVector() && NumSrcEltBits > 1 && (BitWidth % NumSrcEltBits) == 0 && TLO.DAG.getDataLayout().isLittleEndian()) { unsigned Scale = BitWidth / NumSrcEltBits; - auto GetDemandedSubMask = [&](APInt &DemandedSubElts) -> bool { - DemandedSubElts = APInt::getNullValue(Scale); - for (unsigned i = 0; i != Scale; ++i) { - unsigned Offset = i * NumSrcEltBits; - APInt Sub = DemandedBits.extractBits(NumSrcEltBits, Offset); - if (!Sub.isNullValue()) - DemandedSubElts.setBit(i); + unsigned NumSrcElts = SrcVT.getVectorNumElements(); + APInt DemandedSrcBits = APInt::getNullValue(NumSrcEltBits); + APInt DemandedSrcElts = APInt::getNullValue(NumSrcElts); + for (unsigned i = 0; i != Scale; ++i) { + unsigned Offset = i * NumSrcEltBits; + APInt Sub = DemandedBits.extractBits(NumSrcEltBits, Offset); + if (!Sub.isNullValue()) { + DemandedSrcBits |= Sub; + for (unsigned j = 0; j != NumElts; ++j) + if (DemandedElts[j]) + DemandedSrcElts.setBit((j * Scale) + i); } + } + + APInt KnownSrcUndef, KnownSrcZero; + if (SimplifyDemandedVectorElts(Src, DemandedSrcElts, KnownSrcUndef, + KnownSrcZero, TLO, Depth + 1)) return true; - }; - APInt DemandedSubElts; - if (GetDemandedSubMask(DemandedSubElts)) { - unsigned NumSrcElts = SrcVT.getVectorNumElements(); - APInt DemandedElts = APInt::getSplat(NumSrcElts, DemandedSubElts); + KnownBits KnownSrcBits; + if (SimplifyDemandedBits(Src, DemandedSrcBits, DemandedSrcElts, + KnownSrcBits, TLO, Depth + 1)) + return true; + } else if ((NumSrcEltBits % BitWidth) == 0 && + TLO.DAG.getDataLayout().isLittleEndian()) { + unsigned Scale = NumSrcEltBits / BitWidth; + unsigned NumSrcElts = SrcVT.isVector() ? SrcVT.getVectorNumElements() : 1; + APInt DemandedSrcBits = APInt::getNullValue(NumSrcEltBits); + APInt DemandedSrcElts = APInt::getNullValue(NumSrcElts); + for (unsigned i = 0; i != NumElts; ++i) + if (DemandedElts[i]) { + unsigned Offset = (i % Scale) * BitWidth; + DemandedSrcBits.insertBits(DemandedBits, Offset); + DemandedSrcElts.setBit(i / Scale); + } - APInt KnownUndef, KnownZero; - if (SimplifyDemandedVectorElts(Src, DemandedElts, KnownUndef, KnownZero, - TLO, Depth + 1)) + if (SrcVT.isVector()) { + APInt KnownSrcUndef, KnownSrcZero; + if (SimplifyDemandedVectorElts(Src, DemandedSrcElts, KnownSrcUndef, + KnownSrcZero, TLO, Depth + 1)) return true; } + + KnownBits KnownSrcBits; + if (SimplifyDemandedBits(Src, DemandedSrcBits, DemandedSrcElts, + KnownSrcBits, TLO, Depth + 1)) + return true; } + // If this is a bitcast, let computeKnownBits handle it. Only do this on a // recursive call where Known may be useful to the caller. if (Depth > 0) { - Known = TLO.DAG.computeKnownBits(Op, Depth); + Known = TLO.DAG.computeKnownBits(Op, DemandedElts, Depth); return false; } break; @@ -1343,8 +1665,10 @@ bool TargetLowering::SimplifyDemandedBits( SDValue Op0 = Op.getOperand(0), Op1 = Op.getOperand(1); unsigned DemandedBitsLZ = DemandedBits.countLeadingZeros(); APInt LoMask = APInt::getLowBitsSet(BitWidth, BitWidth - DemandedBitsLZ); - if (SimplifyDemandedBits(Op0, LoMask, DemandedElts, Known2, TLO, Depth + 1) || - SimplifyDemandedBits(Op1, LoMask, DemandedElts, Known2, TLO, Depth + 1) || + if (SimplifyDemandedBits(Op0, LoMask, DemandedElts, Known2, TLO, + Depth + 1) || + SimplifyDemandedBits(Op1, LoMask, DemandedElts, Known2, TLO, + Depth + 1) || // See if the operation should be performed at a smaller bit width. ShrinkDemandedOp(Op, BitWidth, DemandedBits, TLO)) { SDNodeFlags Flags = Op.getNode()->getFlags(); @@ -1353,8 +1677,8 @@ bool TargetLowering::SimplifyDemandedBits( // won't wrap after simplification. Flags.setNoSignedWrap(false); Flags.setNoUnsignedWrap(false); - SDValue NewOp = TLO.DAG.getNode(Op.getOpcode(), dl, VT, Op0, Op1, - Flags); + SDValue NewOp = + TLO.DAG.getNode(Op.getOpcode(), dl, VT, Op0, Op1, Flags); return TLO.CombineTo(Op, NewOp); } return true; @@ -1431,15 +1755,64 @@ bool TargetLowering::SimplifyDemandedVectorElts(SDValue Op, DCI.AddToWorklist(Op.getNode()); DCI.CommitTargetLoweringOpt(TLO); } + return Simplified; } +/// Given a vector binary operation and known undefined elements for each input +/// operand, compute whether each element of the output is undefined. +static APInt getKnownUndefForVectorBinop(SDValue BO, SelectionDAG &DAG, + const APInt &UndefOp0, + const APInt &UndefOp1) { + EVT VT = BO.getValueType(); + assert(DAG.getTargetLoweringInfo().isBinOp(BO.getOpcode()) && VT.isVector() && + "Vector binop only"); + + EVT EltVT = VT.getVectorElementType(); + unsigned NumElts = VT.getVectorNumElements(); + assert(UndefOp0.getBitWidth() == NumElts && + UndefOp1.getBitWidth() == NumElts && "Bad type for undef analysis"); + + auto getUndefOrConstantElt = [&](SDValue V, unsigned Index, + const APInt &UndefVals) { + if (UndefVals[Index]) + return DAG.getUNDEF(EltVT); + + if (auto *BV = dyn_cast<BuildVectorSDNode>(V)) { + // Try hard to make sure that the getNode() call is not creating temporary + // nodes. Ignore opaque integers because they do not constant fold. + SDValue Elt = BV->getOperand(Index); + auto *C = dyn_cast<ConstantSDNode>(Elt); + if (isa<ConstantFPSDNode>(Elt) || Elt.isUndef() || (C && !C->isOpaque())) + return Elt; + } + + return SDValue(); + }; + + APInt KnownUndef = APInt::getNullValue(NumElts); + for (unsigned i = 0; i != NumElts; ++i) { + // If both inputs for this element are either constant or undef and match + // the element type, compute the constant/undef result for this element of + // the vector. + // TODO: Ideally we would use FoldConstantArithmetic() here, but that does + // not handle FP constants. The code within getNode() should be refactored + // to avoid the danger of creating a bogus temporary node here. + SDValue C0 = getUndefOrConstantElt(BO.getOperand(0), i, UndefOp0); + SDValue C1 = getUndefOrConstantElt(BO.getOperand(1), i, UndefOp1); + if (C0 && C1 && C0.getValueType() == EltVT && C1.getValueType() == EltVT) + if (DAG.getNode(BO.getOpcode(), SDLoc(BO), EltVT, C0, C1).isUndef()) + KnownUndef.setBit(i); + } + return KnownUndef; +} + bool TargetLowering::SimplifyDemandedVectorElts( - SDValue Op, const APInt &DemandedEltMask, APInt &KnownUndef, + SDValue Op, const APInt &OriginalDemandedElts, APInt &KnownUndef, APInt &KnownZero, TargetLoweringOpt &TLO, unsigned Depth, bool AssumeSingleUse) const { EVT VT = Op.getValueType(); - APInt DemandedElts = DemandedEltMask; + APInt DemandedElts = OriginalDemandedElts; unsigned NumElts = DemandedElts.getBitWidth(); assert(VT.isVector() && "Expected vector op"); assert(VT.getVectorNumElements() == NumElts && @@ -1617,7 +1990,7 @@ bool TargetLowering::SimplifyDemandedVectorElts( SDValue Sub = Op.getOperand(1); EVT SubVT = Sub.getValueType(); unsigned NumSubElts = SubVT.getVectorNumElements(); - const APInt& Idx = cast<ConstantSDNode>(Op.getOperand(2))->getAPIntValue(); + const APInt &Idx = Op.getConstantOperandAPInt(2); if (Idx.ugt(NumElts - NumSubElts)) break; unsigned SubIdx = Idx.getZExtValue(); @@ -1786,18 +2159,26 @@ bool TargetLowering::SimplifyDemandedVectorElts( } break; } + case ISD::ANY_EXTEND_VECTOR_INREG: case ISD::SIGN_EXTEND_VECTOR_INREG: case ISD::ZERO_EXTEND_VECTOR_INREG: { APInt SrcUndef, SrcZero; SDValue Src = Op.getOperand(0); unsigned NumSrcElts = Src.getValueType().getVectorNumElements(); APInt DemandedSrcElts = DemandedElts.zextOrSelf(NumSrcElts); - if (SimplifyDemandedVectorElts(Src, DemandedSrcElts, SrcUndef, - SrcZero, TLO, Depth + 1)) + if (SimplifyDemandedVectorElts(Src, DemandedSrcElts, SrcUndef, SrcZero, TLO, + Depth + 1)) return true; KnownZero = SrcZero.zextOrTrunc(NumElts); KnownUndef = SrcUndef.zextOrTrunc(NumElts); + if (Op.getOpcode() == ISD::ANY_EXTEND_VECTOR_INREG && + Op.getValueSizeInBits() == Src.getValueSizeInBits() && + DemandedSrcElts == 1 && TLO.DAG.getDataLayout().isLittleEndian()) { + // aext - if we just need the bottom element then we can bitcast. + return TLO.CombineTo(Op, TLO.DAG.getBitcast(VT, Src)); + } + if (Op.getOpcode() == ISD::ZERO_EXTEND_VECTOR_INREG) { // zext(undef) upper bits are guaranteed to be zero. if (DemandedElts.isSubsetOf(KnownUndef)) @@ -1806,6 +2187,9 @@ bool TargetLowering::SimplifyDemandedVectorElts( } break; } + + // TODO: There are more binop opcodes that could be handled here - MUL, MIN, + // MAX, saturated math, etc. case ISD::OR: case ISD::XOR: case ISD::ADD: @@ -1815,17 +2199,38 @@ bool TargetLowering::SimplifyDemandedVectorElts( case ISD::FMUL: case ISD::FDIV: case ISD::FREM: { - APInt SrcUndef, SrcZero; - if (SimplifyDemandedVectorElts(Op.getOperand(1), DemandedElts, SrcUndef, - SrcZero, TLO, Depth + 1)) + APInt UndefRHS, ZeroRHS; + if (SimplifyDemandedVectorElts(Op.getOperand(1), DemandedElts, UndefRHS, + ZeroRHS, TLO, Depth + 1)) return true; - if (SimplifyDemandedVectorElts(Op.getOperand(0), DemandedElts, KnownUndef, - KnownZero, TLO, Depth + 1)) + APInt UndefLHS, ZeroLHS; + if (SimplifyDemandedVectorElts(Op.getOperand(0), DemandedElts, UndefLHS, + ZeroLHS, TLO, Depth + 1)) return true; - KnownZero &= SrcZero; - KnownUndef &= SrcUndef; + + KnownZero = ZeroLHS & ZeroRHS; + KnownUndef = getKnownUndefForVectorBinop(Op, TLO.DAG, UndefLHS, UndefRHS); + break; + } + case ISD::SHL: + case ISD::SRL: + case ISD::SRA: + case ISD::ROTL: + case ISD::ROTR: { + APInt UndefRHS, ZeroRHS; + if (SimplifyDemandedVectorElts(Op.getOperand(1), DemandedElts, UndefRHS, + ZeroRHS, TLO, Depth + 1)) + return true; + APInt UndefLHS, ZeroLHS; + if (SimplifyDemandedVectorElts(Op.getOperand(0), DemandedElts, UndefLHS, + ZeroLHS, TLO, Depth + 1)) + return true; + + KnownZero = ZeroLHS; + KnownUndef = UndefLHS & UndefRHS; // TODO: use getKnownUndefForVectorBinop? break; } + case ISD::MUL: case ISD::AND: { APInt SrcUndef, SrcZero; if (SimplifyDemandedVectorElts(Op.getOperand(1), DemandedElts, SrcUndef, @@ -1837,6 +2242,8 @@ bool TargetLowering::SimplifyDemandedVectorElts( // If either side has a zero element, then the result element is zero, even // if the other is an UNDEF. + // TODO: Extend getKnownUndefForVectorBinop to also deal with known zeros + // and then handle 'and' nodes with the rest of the binop opcodes. KnownZero |= SrcZero; KnownUndef &= SrcUndef; KnownUndef &= ~KnownZero; @@ -1864,8 +2271,8 @@ bool TargetLowering::SimplifyDemandedVectorElts( } else { KnownBits Known; APInt DemandedBits = APInt::getAllOnesValue(EltSizeInBits); - if (SimplifyDemandedBits(Op, DemandedBits, DemandedEltMask, Known, TLO, - Depth, AssumeSingleUse)) + if (SimplifyDemandedBits(Op, DemandedBits, OriginalDemandedElts, Known, + TLO, Depth, AssumeSingleUse)) return true; } break; @@ -1950,6 +2357,10 @@ bool TargetLowering::SimplifyDemandedBitsForTargetNode( return false; } +const Constant *TargetLowering::getTargetConstantFromLoad(LoadSDNode*) const { + return nullptr; +} + bool TargetLowering::isKnownNeverNaNForTargetNode(SDValue Op, const SelectionDAG &DAG, bool SNaN, @@ -2044,10 +2455,9 @@ bool TargetLowering::isExtendedTrueVal(const ConstantSDNode *N, EVT VT, /// This helper function of SimplifySetCC tries to optimize the comparison when /// either operand of the SetCC node is a bitwise-and instruction. -SDValue TargetLowering::simplifySetCCWithAnd(EVT VT, SDValue N0, SDValue N1, - ISD::CondCode Cond, - DAGCombinerInfo &DCI, - const SDLoc &DL) const { +SDValue TargetLowering::foldSetCCWithAnd(EVT VT, SDValue N0, SDValue N1, + ISD::CondCode Cond, const SDLoc &DL, + DAGCombinerInfo &DCI) const { // Match these patterns in any of their permutations: // (X & Y) == Y // (X & Y) != Y @@ -2200,6 +2610,49 @@ SDValue TargetLowering::optimizeSetCCOfSignedTruncationCheck( return T2; } +/// Try to fold an equality comparison with a {add/sub/xor} binary operation as +/// the 1st operand (N0). Callers are expected to swap the N0/N1 parameters to +/// handle the commuted versions of these patterns. +SDValue TargetLowering::foldSetCCWithBinOp(EVT VT, SDValue N0, SDValue N1, + ISD::CondCode Cond, const SDLoc &DL, + DAGCombinerInfo &DCI) const { + unsigned BOpcode = N0.getOpcode(); + assert((BOpcode == ISD::ADD || BOpcode == ISD::SUB || BOpcode == ISD::XOR) && + "Unexpected binop"); + assert((Cond == ISD::SETEQ || Cond == ISD::SETNE) && "Unexpected condcode"); + + // (X + Y) == X --> Y == 0 + // (X - Y) == X --> Y == 0 + // (X ^ Y) == X --> Y == 0 + SelectionDAG &DAG = DCI.DAG; + EVT OpVT = N0.getValueType(); + SDValue X = N0.getOperand(0); + SDValue Y = N0.getOperand(1); + if (X == N1) + return DAG.getSetCC(DL, VT, Y, DAG.getConstant(0, DL, OpVT), Cond); + + if (Y != N1) + return SDValue(); + + // (X + Y) == Y --> X == 0 + // (X ^ Y) == Y --> X == 0 + if (BOpcode == ISD::ADD || BOpcode == ISD::XOR) + return DAG.getSetCC(DL, VT, X, DAG.getConstant(0, DL, OpVT), Cond); + + // The shift would not be valid if the operands are boolean (i1). + if (!N0.hasOneUse() || OpVT.getScalarSizeInBits() == 1) + return SDValue(); + + // (X - Y) == Y --> X == Y << 1 + EVT ShiftVT = getShiftAmountTy(OpVT, DAG.getDataLayout(), + !DCI.isBeforeLegalize()); + SDValue One = DAG.getConstant(1, DL, ShiftVT); + SDValue YShl1 = DAG.getNode(ISD::SHL, DL, N1.getValueType(), Y, One); + if (!DCI.isCalledByLegalizer()) + DCI.AddToWorklist(YShl1.getNode()); + return DAG.getSetCC(DL, VT, X, YShl1, Cond); +} + /// Try to simplify a setcc built with the specified operands and cc. If it is /// unable to simplify it, return a null SDValue. SDValue TargetLowering::SimplifySetCC(EVT VT, SDValue N0, SDValue N1, @@ -2209,14 +2662,9 @@ SDValue TargetLowering::SimplifySetCC(EVT VT, SDValue N0, SDValue N1, SelectionDAG &DAG = DCI.DAG; EVT OpVT = N0.getValueType(); - // These setcc operations always fold. - switch (Cond) { - default: break; - case ISD::SETFALSE: - case ISD::SETFALSE2: return DAG.getBoolConstant(false, dl, VT, OpVT); - case ISD::SETTRUE: - case ISD::SETTRUE2: return DAG.getBoolConstant(true, dl, VT, OpVT); - } + // Constant fold or commute setcc. + if (SDValue Fold = DAG.FoldSetCC(VT, N0, N1, Cond, dl)) + return Fold; // Ensure that the constant occurs on the RHS and fold constant comparisons. // TODO: Handle non-splat vector constants. All undef causes trouble. @@ -2226,6 +2674,17 @@ SDValue TargetLowering::SimplifySetCC(EVT VT, SDValue N0, SDValue N1, isCondCodeLegal(SwappedCC, N0.getSimpleValueType()))) return DAG.getSetCC(dl, VT, N1, N0, SwappedCC); + // If we have a subtract with the same 2 non-constant operands as this setcc + // -- but in reverse order -- then try to commute the operands of this setcc + // to match. A matching pair of setcc (cmp) and sub may be combined into 1 + // instruction on some targets. + if (!isConstOrConstSplat(N0) && !isConstOrConstSplat(N1) && + (DCI.isBeforeLegalizeOps() || + isCondCodeLegal(SwappedCC, N0.getSimpleValueType())) && + DAG.getNodeIfExists(ISD::SUB, DAG.getVTList(OpVT), { N1, N0 } ) && + !DAG.getNodeIfExists(ISD::SUB, DAG.getVTList(OpVT), { N0, N1 } )) + return DAG.getSetCC(dl, VT, N1, N0, SwappedCC); + if (auto *N1C = dyn_cast<ConstantSDNode>(N1.getNode())) { const APInt &C1 = N1C->getAPIntValue(); @@ -2235,8 +2694,7 @@ SDValue TargetLowering::SimplifySetCC(EVT VT, SDValue N0, SDValue N1, if (N0.getOpcode() == ISD::SRL && (C1.isNullValue() || C1.isOneValue()) && N0.getOperand(0).getOpcode() == ISD::CTLZ && N0.getOperand(1).getOpcode() == ISD::Constant) { - const APInt &ShAmt - = cast<ConstantSDNode>(N0.getOperand(1))->getAPIntValue(); + const APInt &ShAmt = N0.getConstantOperandAPInt(1); if ((Cond == ISD::SETEQ || Cond == ISD::SETNE) && ShAmt == Log2_32(N0.getValueSizeInBits())) { if ((C1 == 0) == (Cond == ISD::SETEQ)) { @@ -2275,7 +2733,21 @@ SDValue TargetLowering::SimplifySetCC(EVT VT, SDValue N0, SDValue N1, return DAG.getSetCC(dl, VT, And, DAG.getConstant(0, dl, CTVT), CC); } - // TODO: (ctpop x) == 1 -> x && (x & x-1) == 0 iff ctpop is illegal. + // If ctpop is not supported, expand a power-of-2 comparison based on it. + if (C1 == 1 && !isOperationLegalOrCustom(ISD::CTPOP, CTVT) && + (Cond == ISD::SETEQ || Cond == ISD::SETNE)) { + // (ctpop x) == 1 --> (x != 0) && ((x & x-1) == 0) + // (ctpop x) != 1 --> (x == 0) || ((x & x-1) != 0) + SDValue Zero = DAG.getConstant(0, dl, CTVT); + SDValue NegOne = DAG.getAllOnesConstant(dl, CTVT); + ISD::CondCode InvCond = ISD::getSetCCInverse(Cond, true); + SDValue Add = DAG.getNode(ISD::ADD, dl, CTVT, CTOp, NegOne); + SDValue And = DAG.getNode(ISD::AND, dl, CTVT, CTOp, Add); + SDValue LHS = DAG.getSetCC(dl, VT, CTOp, Zero, InvCond); + SDValue RHS = DAG.getSetCC(dl, VT, And, Zero, Cond); + unsigned LogicOpcode = Cond == ISD::SETEQ ? ISD::AND : ISD::OR; + return DAG.getNode(LogicOpcode, dl, VT, LHS, RHS); + } } // (zext x) == C --> x == (trunc C) @@ -2387,8 +2859,7 @@ SDValue TargetLowering::SimplifySetCC(EVT VT, SDValue N0, SDValue N1, // 8 bits, but have to be careful... if (Lod->getExtensionType() != ISD::NON_EXTLOAD) origWidth = Lod->getMemoryVT().getSizeInBits(); - const APInt &Mask = - cast<ConstantSDNode>(N0.getOperand(1))->getAPIntValue(); + const APInt &Mask = N0.getConstantOperandAPInt(1); for (unsigned width = origWidth / 2; width>=8; width /= 2) { APInt newMask = APInt::getLowBitsSet(maskWidth, width); for (unsigned offset=0; offset<origWidth/width; offset++) { @@ -2480,7 +2951,7 @@ SDValue TargetLowering::SimplifySetCC(EVT VT, SDValue N0, SDValue N1, break; } default: - break; // todo, be more careful with signed comparisons + break; // todo, be more careful with signed comparisons } } else if (N0.getOpcode() == ISD::SIGN_EXTEND_INREG && (Cond == ISD::SETEQ || Cond == ISD::SETNE)) { @@ -2501,7 +2972,7 @@ SDValue TargetLowering::SimplifySetCC(EVT VT, SDValue N0, SDValue N1, } else { APInt Imm = APInt::getLowBitsSet(ExtDstTyBits, ExtSrcTyBits); ZextOp = DAG.getNode(ISD::AND, dl, Op0Ty, N0.getOperand(0), - DAG.getConstant(Imm, dl, Op0Ty)); + DAG.getConstant(Imm, dl, Op0Ty)); } if (!DCI.isCalledByLegalizer()) DCI.AddToWorklist(ZextOp.getNode()); @@ -2598,6 +3069,18 @@ SDValue TargetLowering::SimplifySetCC(EVT VT, SDValue N0, SDValue N1, } } + // Given: + // icmp eq/ne (urem %x, %y), 0 + // Iff %x has 0 or 1 bits set, and %y has at least 2 bits set, omit 'urem': + // icmp eq/ne %x, 0 + if (N0.getOpcode() == ISD::UREM && N1C->isNullValue() && + (Cond == ISD::SETEQ || Cond == ISD::SETNE)) { + KnownBits XKnown = DAG.computeKnownBits(N0.getOperand(0)); + KnownBits YKnown = DAG.computeKnownBits(N0.getOperand(1)); + if (XKnown.countMaxPopulation() == 1 && YKnown.countMinPopulation() >= 2) + return DAG.getSetCC(dl, VT, N0.getOperand(0), N1, Cond); + } + if (SDValue V = optimizeSetCCOfSignedTruncationCheck(VT, N0, N1, Cond, DCI, dl)) return V; @@ -2805,25 +3288,9 @@ SDValue TargetLowering::SimplifySetCC(EVT VT, SDValue N0, SDValue N1, } } - if (isa<ConstantFPSDNode>(N0.getNode())) { - // Constant fold or commute setcc. - SDValue O = DAG.FoldSetCC(VT, N0, N1, Cond, dl); - if (O.getNode()) return O; - } else if (auto *CFP = dyn_cast<ConstantFPSDNode>(N1.getNode())) { - // If the RHS of an FP comparison is a constant, simplify it away in - // some cases. - if (CFP->getValueAPF().isNaN()) { - // If an operand is known to be a nan, we can fold it. - switch (ISD::getUnorderedFlavor(Cond)) { - default: llvm_unreachable("Unknown flavor!"); - case 0: // Known false. - return DAG.getBoolConstant(false, dl, VT, OpVT); - case 1: // Known true. - return DAG.getBoolConstant(true, dl, VT, OpVT); - case 2: // Undefined. - return DAG.getUNDEF(VT); - } - } + if (!isa<ConstantFPSDNode>(N0) && isa<ConstantFPSDNode>(N1)) { + auto *CFP = cast<ConstantFPSDNode>(N1); + assert(!CFP->getValueAPF().isNaN() && "Unexpected NaN value"); // Otherwise, we know the RHS is not a NaN. Simplify the node to drop the // constant if knowing that the operand is non-nan is enough. We prefer to @@ -2883,15 +3350,12 @@ SDValue TargetLowering::SimplifySetCC(EVT VT, SDValue N0, SDValue N1, if (N0 == N1) { // The sext(setcc()) => setcc() optimization relies on the appropriate // constant being emitted. + assert(!N0.getValueType().isInteger() && + "Integer types should be handled by FoldSetCC"); bool EqTrue = ISD::isTrueWhenEqual(Cond); - - // We can always fold X == X for integer setcc's. - if (N0.getValueType().isInteger()) - return DAG.getBoolConstant(EqTrue, dl, VT, OpVT); - unsigned UOF = ISD::getUnorderedFlavor(Cond); - if (UOF == 2) // FP operators that are undefined on NaNs. + if (UOF == 2) // FP operators that are undefined on NaNs. return DAG.getBoolConstant(EqTrue, dl, VT, OpVT); if (UOF == unsigned(EqTrue)) return DAG.getBoolConstant(EqTrue, dl, VT, OpVT); @@ -2900,7 +3364,7 @@ SDValue TargetLowering::SimplifySetCC(EVT VT, SDValue N0, SDValue N1, ISD::CondCode NewCond = UOF == 0 ? ISD::SETO : ISD::SETUO; if (NewCond != Cond && (DCI.isBeforeLegalizeOps() || - isCondCodeLegal(NewCond, N0.getSimpleValueType()))) + isCondCodeLegal(NewCond, N0.getSimpleValueType()))) return DAG.getSetCC(dl, VT, N0, N1, NewCond); } @@ -2969,69 +3433,39 @@ SDValue TargetLowering::SimplifySetCC(EVT VT, SDValue N0, SDValue N1, LegalRHSImm = isLegalICmpImmediate(RHSC->getSExtValue()); } - // Simplify (X+Z) == X --> Z == 0 + // (X+Y) == X --> Y == 0 and similar folds. // Don't do this if X is an immediate that can fold into a cmp - // instruction and X+Z has other uses. It could be an induction variable + // instruction and X+Y has other uses. It could be an induction variable // chain, and the transform would increase register pressure. - if (!LegalRHSImm || N0.getNode()->hasOneUse()) { - if (N0.getOperand(0) == N1) - return DAG.getSetCC(dl, VT, N0.getOperand(1), - DAG.getConstant(0, dl, N0.getValueType()), Cond); - if (N0.getOperand(1) == N1) { - if (isCommutativeBinOp(N0.getOpcode())) - return DAG.getSetCC(dl, VT, N0.getOperand(0), - DAG.getConstant(0, dl, N0.getValueType()), - Cond); - if (N0.getNode()->hasOneUse()) { - assert(N0.getOpcode() == ISD::SUB && "Unexpected operation!"); - auto &DL = DAG.getDataLayout(); - // (Z-X) == X --> Z == X<<1 - SDValue SH = DAG.getNode( - ISD::SHL, dl, N1.getValueType(), N1, - DAG.getConstant(1, dl, - getShiftAmountTy(N1.getValueType(), DL, - !DCI.isBeforeLegalize()))); - if (!DCI.isCalledByLegalizer()) - DCI.AddToWorklist(SH.getNode()); - return DAG.getSetCC(dl, VT, N0.getOperand(0), SH, Cond); - } - } - } + if (!LegalRHSImm || N0.hasOneUse()) + if (SDValue V = foldSetCCWithBinOp(VT, N0, N1, Cond, dl, DCI)) + return V; } if (N1.getOpcode() == ISD::ADD || N1.getOpcode() == ISD::SUB || - N1.getOpcode() == ISD::XOR) { - // Simplify X == (X+Z) --> Z == 0 - if (N1.getOperand(0) == N0) - return DAG.getSetCC(dl, VT, N1.getOperand(1), - DAG.getConstant(0, dl, N1.getValueType()), Cond); - if (N1.getOperand(1) == N0) { - if (isCommutativeBinOp(N1.getOpcode())) - return DAG.getSetCC(dl, VT, N1.getOperand(0), - DAG.getConstant(0, dl, N1.getValueType()), Cond); - if (N1.getNode()->hasOneUse()) { - assert(N1.getOpcode() == ISD::SUB && "Unexpected operation!"); - auto &DL = DAG.getDataLayout(); - // X == (Z-X) --> X<<1 == Z - SDValue SH = DAG.getNode( - ISD::SHL, dl, N1.getValueType(), N0, - DAG.getConstant(1, dl, getShiftAmountTy(N0.getValueType(), DL, - !DCI.isBeforeLegalize()))); - if (!DCI.isCalledByLegalizer()) - DCI.AddToWorklist(SH.getNode()); - return DAG.getSetCC(dl, VT, SH, N1.getOperand(0), Cond); - } - } - } + N1.getOpcode() == ISD::XOR) + if (SDValue V = foldSetCCWithBinOp(VT, N1, N0, Cond, dl, DCI)) + return V; - if (SDValue V = simplifySetCCWithAnd(VT, N0, N1, Cond, DCI, dl)) + if (SDValue V = foldSetCCWithAnd(VT, N0, N1, Cond, dl, DCI)) return V; } + // Fold remainder of division by a constant. + if (N0.getOpcode() == ISD::UREM && N0.hasOneUse() && + (Cond == ISD::SETEQ || Cond == ISD::SETNE)) { + AttributeList Attr = DAG.getMachineFunction().getFunction().getAttributes(); + + // When division is cheap or optimizing for minimum size, + // fall through to DIVREM creation by skipping this fold. + if (!isIntDivCheap(VT, Attr) && !Attr.hasFnAttribute(Attribute::MinSize)) + if (SDValue Folded = buildUREMEqFold(VT, N0, N1, Cond, DCI, dl)) + return Folded; + } + // Fold away ALL boolean setcc's. - SDValue Temp; if (N0.getValueType().getScalarType() == MVT::i1 && foldBooleans) { - EVT OpVT = N0.getValueType(); + SDValue Temp; switch (Cond) { default: llvm_unreachable("Unknown integer setcc!"); case ISD::SETEQ: // X == Y -> ~(X^Y) @@ -3134,18 +3568,18 @@ TargetLowering::getConstraintType(StringRef Constraint) const { switch (Constraint[0]) { default: break; case 'r': return C_RegisterClass; - case 'm': // memory - case 'o': // offsetable - case 'V': // not offsetable + case 'm': // memory + case 'o': // offsetable + case 'V': // not offsetable return C_Memory; - case 'i': // Simple Integer or Relocatable Constant - case 'n': // Simple Integer - case 'E': // Floating Point Constant - case 'F': // Floating Point Constant - case 's': // Relocatable Constant - case 'p': // Address. - case 'X': // Allow ANY value. - case 'I': // Target registers. + case 'i': // Simple Integer or Relocatable Constant + case 'n': // Simple Integer + case 'E': // Floating Point Constant + case 'F': // Floating Point Constant + case 's': // Relocatable Constant + case 'p': // Address. + case 'X': // Allow ANY value. + case 'I': // Target registers. case 'J': case 'K': case 'L': @@ -3159,7 +3593,7 @@ TargetLowering::getConstraintType(StringRef Constraint) const { } } - if (S > 1 && Constraint[0] == '{' && Constraint[S-1] == '}') { + if (S > 1 && Constraint[0] == '{' && Constraint[S - 1] == '}') { if (S == 8 && Constraint.substr(1, 6) == "memory") // "{memory}" return C_Memory; return C_Register; @@ -3170,14 +3604,20 @@ TargetLowering::getConstraintType(StringRef Constraint) const { /// Try to replace an X constraint, which matches anything, with another that /// has more specific requirements based on the type of the corresponding /// operand. -const char *TargetLowering::LowerXConstraint(EVT ConstraintVT) const{ +const char *TargetLowering::LowerXConstraint(EVT ConstraintVT) const { if (ConstraintVT.isInteger()) return "r"; if (ConstraintVT.isFloatingPoint()) - return "f"; // works for many targets + return "f"; // works for many targets return nullptr; } +SDValue TargetLowering::LowerAsmOutputForConstraint( + SDValue &Chain, SDValue &Flag, SDLoc DL, const AsmOperandInfo &OpInfo, + SelectionDAG &DAG) const { + return SDValue(); +} + /// Lower the specified operand into the Ops vector. /// If it is invalid, don't add anything to Ops. void TargetLowering::LowerAsmOperandForConstraint(SDValue Op, @@ -3191,7 +3631,8 @@ void TargetLowering::LowerAsmOperandForConstraint(SDValue Op, switch (ConstraintLetter) { default: break; case 'X': // Allows any operand; labels (basic block) use this. - if (Op.getOpcode() == ISD::BasicBlock) { + if (Op.getOpcode() == ISD::BasicBlock || + Op.getOpcode() == ISD::TargetBlockAddress) { Ops.push_back(Op); return; } @@ -3199,46 +3640,57 @@ void TargetLowering::LowerAsmOperandForConstraint(SDValue Op, case 'i': // Simple Integer or Relocatable Constant case 'n': // Simple Integer case 's': { // Relocatable Constant - // These operands are interested in values of the form (GV+C), where C may - // be folded in as an offset of GV, or it may be explicitly added. Also, it - // is possible and fine if either GV or C are missing. - ConstantSDNode *C = dyn_cast<ConstantSDNode>(Op); - GlobalAddressSDNode *GA = dyn_cast<GlobalAddressSDNode>(Op); - - // If we have "(add GV, C)", pull out GV/C - if (Op.getOpcode() == ISD::ADD) { - C = dyn_cast<ConstantSDNode>(Op.getOperand(1)); - GA = dyn_cast<GlobalAddressSDNode>(Op.getOperand(0)); - if (!C || !GA) { - C = dyn_cast<ConstantSDNode>(Op.getOperand(0)); - GA = dyn_cast<GlobalAddressSDNode>(Op.getOperand(1)); - } - if (!C || !GA) { - C = nullptr; - GA = nullptr; - } - } - // If we find a valid operand, map to the TargetXXX version so that the - // value itself doesn't get selected. - if (GA) { // Either &GV or &GV+C - if (ConstraintLetter != 'n') { - int64_t Offs = GA->getOffset(); - if (C) Offs += C->getZExtValue(); - Ops.push_back(DAG.getTargetGlobalAddress(GA->getGlobal(), - C ? SDLoc(C) : SDLoc(), - Op.getValueType(), Offs)); - } - return; - } - if (C) { // just C, no GV. - // Simple constants are not allowed for 's'. - if (ConstraintLetter != 's') { + GlobalAddressSDNode *GA; + ConstantSDNode *C; + BlockAddressSDNode *BA; + uint64_t Offset = 0; + + // Match (GA) or (C) or (GA+C) or (GA-C) or ((GA+C)+C) or (((GA+C)+C)+C), + // etc., since getelementpointer is variadic. We can't use + // SelectionDAG::FoldSymbolOffset because it expects the GA to be accessible + // while in this case the GA may be furthest from the root node which is + // likely an ISD::ADD. + while (1) { + if ((GA = dyn_cast<GlobalAddressSDNode>(Op)) && ConstraintLetter != 'n') { + Ops.push_back(DAG.getTargetGlobalAddress(GA->getGlobal(), SDLoc(Op), + GA->getValueType(0), + Offset + GA->getOffset())); + return; + } else if ((C = dyn_cast<ConstantSDNode>(Op)) && + ConstraintLetter != 's') { // gcc prints these as sign extended. Sign extend value to 64 bits // now; without this it would get ZExt'd later in // ScheduleDAGSDNodes::EmitNode, which is very generic. - Ops.push_back(DAG.getTargetConstant(C->getSExtValue(), + bool IsBool = C->getConstantIntValue()->getBitWidth() == 1; + BooleanContent BCont = getBooleanContents(MVT::i64); + ISD::NodeType ExtOpc = IsBool ? getExtendForContent(BCont) + : ISD::SIGN_EXTEND; + int64_t ExtVal = ExtOpc == ISD::ZERO_EXTEND ? C->getZExtValue() + : C->getSExtValue(); + Ops.push_back(DAG.getTargetConstant(Offset + ExtVal, SDLoc(C), MVT::i64)); + return; + } else if ((BA = dyn_cast<BlockAddressSDNode>(Op)) && + ConstraintLetter != 'n') { + Ops.push_back(DAG.getTargetBlockAddress( + BA->getBlockAddress(), BA->getValueType(0), + Offset + BA->getOffset(), BA->getTargetFlags())); + return; + } else { + const unsigned OpCode = Op.getOpcode(); + if (OpCode == ISD::ADD || OpCode == ISD::SUB) { + if ((C = dyn_cast<ConstantSDNode>(Op.getOperand(0)))) + Op = Op.getOperand(1); + // Subtraction is not commutative. + else if (OpCode == ISD::ADD && + (C = dyn_cast<ConstantSDNode>(Op.getOperand(1)))) + Op = Op.getOperand(0); + else + return; + Offset += (OpCode == ISD::ADD ? 1 : -1) * C->getSExtValue(); + continue; + } } return; } @@ -3252,14 +3704,14 @@ TargetLowering::getRegForInlineAsmConstraint(const TargetRegisterInfo *RI, StringRef Constraint, MVT VT) const { if (Constraint.empty() || Constraint[0] != '{') - return std::make_pair(0u, static_cast<TargetRegisterClass*>(nullptr)); - assert(*(Constraint.end()-1) == '}' && "Not a brace enclosed constraint?"); + return std::make_pair(0u, static_cast<TargetRegisterClass *>(nullptr)); + assert(*(Constraint.end() - 1) == '}' && "Not a brace enclosed constraint?"); // Remove the braces from around the name. - StringRef RegName(Constraint.data()+1, Constraint.size()-2); + StringRef RegName(Constraint.data() + 1, Constraint.size() - 2); - std::pair<unsigned, const TargetRegisterClass*> R = - std::make_pair(0u, static_cast<const TargetRegisterClass*>(nullptr)); + std::pair<unsigned, const TargetRegisterClass *> R = + std::make_pair(0u, static_cast<const TargetRegisterClass *>(nullptr)); // Figure out which register class contains this reg. for (const TargetRegisterClass *RC : RI->regclasses()) { @@ -3271,8 +3723,8 @@ TargetLowering::getRegForInlineAsmConstraint(const TargetRegisterInfo *RI, for (TargetRegisterClass::iterator I = RC->begin(), E = RC->end(); I != E; ++I) { if (RegName.equals_lower(RI->getRegAsmName(*I))) { - std::pair<unsigned, const TargetRegisterClass*> S = - std::make_pair(*I, RC); + std::pair<unsigned, const TargetRegisterClass *> S = + std::make_pair(*I, RC); // If this register class has the requested value type, return it, // otherwise keep searching and return the first class found @@ -3321,8 +3773,8 @@ TargetLowering::ParseConstraints(const DataLayout &DL, // Do a prepass over the constraints, canonicalizing them, and building up the // ConstraintOperands list. - unsigned ArgNo = 0; // ArgNo - The argument of the CallInst. - unsigned ResNo = 0; // ResNo - The result number of the next output. + unsigned ArgNo = 0; // ArgNo - The argument of the CallInst. + unsigned ResNo = 0; // ResNo - The result number of the next output. for (InlineAsm::ConstraintInfo &CI : IA->ParseConstraints()) { ConstraintOperands.emplace_back(std::move(CI)); @@ -3391,7 +3843,7 @@ TargetLowering::ParseConstraints(const DataLayout &DL, case 64: case 128: OpInfo.ConstraintVT = - MVT::getVT(IntegerType::get(OpTy->getContext(), BitSize), true); + MVT::getVT(IntegerType::get(OpTy->getContext(), BitSize), true); break; } } else if (PointerType *PT = dyn_cast<PointerType>(OpTy)) { @@ -3416,8 +3868,8 @@ TargetLowering::ParseConstraints(const DataLayout &DL, for (maIndex = 0; maIndex < maCount; ++maIndex) { int weightSum = 0; for (unsigned cIndex = 0, eIndex = ConstraintOperands.size(); - cIndex != eIndex; ++cIndex) { - AsmOperandInfo& OpInfo = ConstraintOperands[cIndex]; + cIndex != eIndex; ++cIndex) { + AsmOperandInfo &OpInfo = ConstraintOperands[cIndex]; if (OpInfo.Type == InlineAsm::isClobber) continue; @@ -3432,7 +3884,7 @@ TargetLowering::ParseConstraints(const DataLayout &DL, Input.ConstraintVT.isInteger()) || (OpInfo.ConstraintVT.getSizeInBits() != Input.ConstraintVT.getSizeInBits())) { - weightSum = -1; // Can't match. + weightSum = -1; // Can't match. break; } } @@ -3453,8 +3905,8 @@ TargetLowering::ParseConstraints(const DataLayout &DL, // Now select chosen alternative in each constraint. for (unsigned cIndex = 0, eIndex = ConstraintOperands.size(); - cIndex != eIndex; ++cIndex) { - AsmOperandInfo& cInfo = ConstraintOperands[cIndex]; + cIndex != eIndex; ++cIndex) { + AsmOperandInfo &cInfo = ConstraintOperands[cIndex]; if (cInfo.Type == InlineAsm::isClobber) continue; cInfo.selectAlternative(bestMAIndex); @@ -3464,8 +3916,8 @@ TargetLowering::ParseConstraints(const DataLayout &DL, // Check and hook up tied operands, choose constraint code to use. for (unsigned cIndex = 0, eIndex = ConstraintOperands.size(); - cIndex != eIndex; ++cIndex) { - AsmOperandInfo& OpInfo = ConstraintOperands[cIndex]; + cIndex != eIndex; ++cIndex) { + AsmOperandInfo &OpInfo = ConstraintOperands[cIndex]; // If this is an output operand with a matching input operand, look up the // matching input. If their types mismatch, e.g. one is an integer, the @@ -3577,9 +4029,9 @@ TargetLowering::ConstraintWeight weight = CW_Register; break; case 'X': // any operand. - default: - weight = CW_Default; - break; + default: + weight = CW_Default; + break; } return weight; } @@ -3678,6 +4130,9 @@ void TargetLowering::ComputeConstraintToUse(AsmOperandInfo &OpInfo, return; } + if (Op.getNode() && Op.getOpcode() == ISD::TargetBlockAddress) + return; + // Otherwise, try to resolve it to something we know about by looking at // the actual operand type. if (const char *Repl = LowerXConstraint(OpInfo.ConstraintVT)) { @@ -3749,12 +4204,12 @@ static SDValue BuildExactSDIV(const TargetLowering &TLI, SDNode *N, } SDValue TargetLowering::BuildSDIVPow2(SDNode *N, const APInt &Divisor, - SelectionDAG &DAG, - SmallVectorImpl<SDNode *> &Created) const { + SelectionDAG &DAG, + SmallVectorImpl<SDNode *> &Created) const { AttributeList Attr = DAG.getMachineFunction().getFunction().getAttributes(); const TargetLowering &TLI = DAG.getTargetLoweringInfo(); if (TLI.isIntDivCheap(N->getValueType(0), Attr)) - return SDValue(N,0); // Lower SDIV as SDIV + return SDValue(N, 0); // Lower SDIV as SDIV return SDValue(); } @@ -4000,6 +4455,104 @@ SDValue TargetLowering::BuildUDIV(SDNode *N, SelectionDAG &DAG, return DAG.getSelect(dl, VT, IsOne, N0, Q); } +/// Given an ISD::UREM used only by an ISD::SETEQ or ISD::SETNE +/// where the divisor is constant and the comparison target is zero, +/// return a DAG expression that will generate the same comparison result +/// using only multiplications, additions and shifts/rotations. +/// Ref: "Hacker's Delight" 10-17. +SDValue TargetLowering::buildUREMEqFold(EVT SETCCVT, SDValue REMNode, + SDValue CompTargetNode, + ISD::CondCode Cond, + DAGCombinerInfo &DCI, + const SDLoc &DL) const { + SmallVector<SDNode *, 2> Built; + if (SDValue Folded = prepareUREMEqFold(SETCCVT, REMNode, CompTargetNode, Cond, + DCI, DL, Built)) { + for (SDNode *N : Built) + DCI.AddToWorklist(N); + return Folded; + } + + return SDValue(); +} + +SDValue +TargetLowering::prepareUREMEqFold(EVT SETCCVT, SDValue REMNode, + SDValue CompTargetNode, ISD::CondCode Cond, + DAGCombinerInfo &DCI, const SDLoc &DL, + SmallVectorImpl<SDNode *> &Created) const { + // fold (seteq/ne (urem N, D), 0) -> (setule/ugt (rotr (mul N, P), K), Q) + // - D must be constant with D = D0 * 2^K where D0 is odd and D0 != 1 + // - P is the multiplicative inverse of D0 modulo 2^W + // - Q = floor((2^W - 1) / D0) + // where W is the width of the common type of N and D. + assert((Cond == ISD::SETEQ || Cond == ISD::SETNE) && + "Only applicable for (in)equality comparisons."); + + EVT VT = REMNode.getValueType(); + + // If MUL is unavailable, we cannot proceed in any case. + if (!isOperationLegalOrCustom(ISD::MUL, VT)) + return SDValue(); + + // TODO: Add non-uniform constant support. + ConstantSDNode *Divisor = isConstOrConstSplat(REMNode->getOperand(1)); + ConstantSDNode *CompTarget = isConstOrConstSplat(CompTargetNode); + if (!Divisor || !CompTarget || Divisor->isNullValue() || + !CompTarget->isNullValue()) + return SDValue(); + + const APInt &D = Divisor->getAPIntValue(); + + // Decompose D into D0 * 2^K + unsigned K = D.countTrailingZeros(); + bool DivisorIsEven = (K != 0); + APInt D0 = D.lshr(K); + + // The fold is invalid when D0 == 1. + // This is reachable because visitSetCC happens before visitREM. + if (D0.isOneValue()) + return SDValue(); + + // P = inv(D0, 2^W) + // 2^W requires W + 1 bits, so we have to extend and then truncate. + unsigned W = D.getBitWidth(); + APInt P = D0.zext(W + 1) + .multiplicativeInverse(APInt::getSignedMinValue(W + 1)) + .trunc(W); + assert(!P.isNullValue() && "No multiplicative inverse!"); // unreachable + assert((D0 * P).isOneValue() && "Multiplicative inverse sanity check."); + + // Q = floor((2^W - 1) / D) + APInt Q = APInt::getAllOnesValue(W).udiv(D); + + SelectionDAG &DAG = DCI.DAG; + + SDValue PVal = DAG.getConstant(P, DL, VT); + SDValue QVal = DAG.getConstant(Q, DL, VT); + // (mul N, P) + SDValue Op1 = DAG.getNode(ISD::MUL, DL, VT, REMNode->getOperand(0), PVal); + Created.push_back(Op1.getNode()); + + // Rotate right only if D was even. + if (DivisorIsEven) { + // We need ROTR to do this. + if (!isOperationLegalOrCustom(ISD::ROTR, VT)) + return SDValue(); + SDValue ShAmt = + DAG.getConstant(K, DL, getShiftAmountTy(VT, DAG.getDataLayout())); + SDNodeFlags Flags; + Flags.setExact(true); + // UREM: (rotr (mul N, P), K) + Op1 = DAG.getNode(ISD::ROTR, DL, VT, Op1, ShAmt, Flags); + Created.push_back(Op1.getNode()); + } + + // UREM: (setule/setugt (rotr (mul N, P), K), Q) + return DAG.getSetCC(DL, SETCCVT, Op1, QVal, + ((Cond == ISD::SETEQ) ? ISD::SETULE : ISD::SETUGT)); +} + bool TargetLowering:: verifyReturnAddressArgumentIsConstant(SDValue Op, SelectionDAG &DAG) const { if (!isa<ConstantSDNode>(Op.getOperand(0))) { @@ -4308,7 +4861,7 @@ bool TargetLowering::expandROT(SDNode *Node, SDValue &Result, } bool TargetLowering::expandFP_TO_SINT(SDNode *Node, SDValue &Result, - SelectionDAG &DAG) const { + SelectionDAG &DAG) const { SDValue Src = Node->getOperand(0); EVT SrcVT = Src.getValueType(); EVT DstVT = Node->getValueType(0); @@ -4320,7 +4873,7 @@ bool TargetLowering::expandFP_TO_SINT(SDNode *Node, SDValue &Result, // Expand f32 -> i64 conversion // This algorithm comes from compiler-rt's implementation of fixsfdi: - // https://github.com/llvm-mirror/compiler-rt/blob/master/lib/builtins/fixsfdi.c + // https://github.com/llvm/llvm-project/blob/master/compiler-rt/lib/builtins/fixsfdi.c unsigned SrcEltBits = SrcVT.getScalarSizeInBits(); EVT IntVT = SrcVT.changeTypeToInteger(); EVT IntShVT = getShiftAmountTy(IntVT, DAG.getDataLayout()); @@ -4544,6 +5097,17 @@ SDValue TargetLowering::expandFMINNUM_FMAXNUM(SDNode *Node, return DAG.getNode(NewOp, dl, VT, Quiet0, Quiet1, Node->getFlags()); } + // If the target has FMINIMUM/FMAXIMUM but not FMINNUM/FMAXNUM use that + // instead if there are no NaNs. + if (Node->getFlags().hasNoNaNs()) { + unsigned IEEE2018Op = + Node->getOpcode() == ISD::FMINNUM ? ISD::FMINIMUM : ISD::FMAXIMUM; + if (isOperationLegalOrCustom(IEEE2018Op, VT)) { + return DAG.getNode(IEEE2018Op, dl, VT, Node->getOperand(0), + Node->getOperand(1), Node->getFlags()); + } + } + return SDValue(); } @@ -4771,7 +5335,7 @@ SDValue TargetLowering::scalarizeVectorLoad(LoadSDNode *LD, SDValue NewChain = DAG.getNode(ISD::TokenFactor, SL, MVT::Other, LoadChains); SDValue Value = DAG.getBuildVector(LD->getValueType(0), SL, Vals); - return DAG.getMergeValues({ Value, NewChain }, SL); + return DAG.getMergeValues({Value, NewChain}, SL); } SDValue TargetLowering::scalarizeVectorStore(StoreSDNode *ST, @@ -4826,7 +5390,7 @@ SDValue TargetLowering::scalarizeVectorStore(StoreSDNode *ST, // Store Stride in bytes unsigned Stride = MemSclVT.getSizeInBits() / 8; - assert (Stride && "Zero stride!"); + assert(Stride && "Zero stride!"); // Extract each of the elements from the original vector and save them into // memory individually. SmallVector<SDValue, 8> Stores; @@ -5013,17 +5577,16 @@ SDValue TargetLowering::expandUnalignedStore(StoreSDNode *ST, EVT VT = Val.getValueType(); int Alignment = ST->getAlignment(); auto &MF = DAG.getMachineFunction(); - EVT MemVT = ST->getMemoryVT(); + EVT StoreMemVT = ST->getMemoryVT(); SDLoc dl(ST); - if (MemVT.isFloatingPoint() || MemVT.isVector()) { + if (StoreMemVT.isFloatingPoint() || StoreMemVT.isVector()) { EVT intVT = EVT::getIntegerVT(*DAG.getContext(), VT.getSizeInBits()); if (isTypeLegal(intVT)) { if (!isOperationLegalOrCustom(ISD::STORE, intVT) && - MemVT.isVector()) { + StoreMemVT.isVector()) { // Scalarize the store and let the individual components be handled. SDValue Result = scalarizeVectorStore(ST, DAG); - return Result; } // Expand to a bitconvert of the value to the integer type of the @@ -5036,24 +5599,22 @@ SDValue TargetLowering::expandUnalignedStore(StoreSDNode *ST, } // Do a (aligned) store to a stack slot, then copy from the stack slot // to the final destination using (unaligned) integer loads and stores. - EVT StoredVT = ST->getMemoryVT(); - MVT RegVT = - getRegisterType(*DAG.getContext(), - EVT::getIntegerVT(*DAG.getContext(), - StoredVT.getSizeInBits())); + MVT RegVT = getRegisterType( + *DAG.getContext(), + EVT::getIntegerVT(*DAG.getContext(), StoreMemVT.getSizeInBits())); EVT PtrVT = Ptr.getValueType(); - unsigned StoredBytes = StoredVT.getStoreSize(); + unsigned StoredBytes = StoreMemVT.getStoreSize(); unsigned RegBytes = RegVT.getSizeInBits() / 8; unsigned NumRegs = (StoredBytes + RegBytes - 1) / RegBytes; // Make sure the stack slot is also aligned for the register type. - SDValue StackPtr = DAG.CreateStackTemporary(StoredVT, RegVT); + SDValue StackPtr = DAG.CreateStackTemporary(StoreMemVT, RegVT); auto FrameIndex = cast<FrameIndexSDNode>(StackPtr.getNode())->getIndex(); // Perform the original store, only redirected to the stack slot. SDValue Store = DAG.getTruncStore( Chain, dl, Val, StackPtr, - MachinePointerInfo::getFixedStack(MF, FrameIndex, 0), StoredVT); + MachinePointerInfo::getFixedStack(MF, FrameIndex, 0), StoreMemVT); EVT StackPtrVT = StackPtr.getValueType(); @@ -5082,17 +5643,17 @@ SDValue TargetLowering::expandUnalignedStore(StoreSDNode *ST, // The last store may be partial. Do a truncating store. On big-endian // machines this requires an extending load from the stack slot to ensure // that the bits are in the right place. - EVT MemVT = EVT::getIntegerVT(*DAG.getContext(), - 8 * (StoredBytes - Offset)); + EVT LoadMemVT = + EVT::getIntegerVT(*DAG.getContext(), 8 * (StoredBytes - Offset)); // Load from the stack slot. SDValue Load = DAG.getExtLoad( ISD::EXTLOAD, dl, RegVT, Store, StackPtr, - MachinePointerInfo::getFixedStack(MF, FrameIndex, Offset), MemVT); + MachinePointerInfo::getFixedStack(MF, FrameIndex, Offset), LoadMemVT); Stores.push_back( DAG.getTruncStore(Load.getValue(1), dl, Load, Ptr, - ST->getPointerInfo().getWithOffset(Offset), MemVT, + ST->getPointerInfo().getWithOffset(Offset), LoadMemVT, MinAlign(ST->getAlignment(), Offset), ST->getMemOperand()->getFlags(), ST->getAAInfo())); // The order of the stores doesn't matter - say it with a TokenFactor. @@ -5100,18 +5661,16 @@ SDValue TargetLowering::expandUnalignedStore(StoreSDNode *ST, return Result; } - assert(ST->getMemoryVT().isInteger() && - !ST->getMemoryVT().isVector() && + assert(StoreMemVT.isInteger() && !StoreMemVT.isVector() && "Unaligned store of unknown type."); // Get the half-size VT - EVT NewStoredVT = ST->getMemoryVT().getHalfSizedIntegerVT(*DAG.getContext()); + EVT NewStoredVT = StoreMemVT.getHalfSizedIntegerVT(*DAG.getContext()); int NumBits = NewStoredVT.getSizeInBits(); int IncrementSize = NumBits / 8; // Divide the stored value in two parts. - SDValue ShiftAmount = - DAG.getConstant(NumBits, dl, getShiftAmountTy(Val.getValueType(), - DAG.getDataLayout())); + SDValue ShiftAmount = DAG.getConstant( + NumBits, dl, getShiftAmountTy(Val.getValueType(), DAG.getDataLayout())); SDValue Lo = Val; SDValue Hi = DAG.getNode(ISD::SRL, dl, VT, Val, ShiftAmount); @@ -5130,7 +5689,7 @@ SDValue TargetLowering::expandUnalignedStore(StoreSDNode *ST, ST->getMemOperand()->getFlags(), ST->getAAInfo()); SDValue Result = - DAG.getNode(ISD::TokenFactor, dl, MVT::Other, Store1, Store2); + DAG.getNode(ISD::TokenFactor, dl, MVT::Other, Store1, Store2); return Result; } @@ -5242,7 +5801,7 @@ SDValue TargetLowering::LowerToTLSEmulatedModel(const GlobalAddressSDNode *GA, // TLSADDR will be codegen'ed as call. Inform MFI that function has calls. // At last for X86 targets, maybe good for other targets too? MachineFrameInfo &MFI = DAG.getMachineFunction().getFrameInfo(); - MFI.setAdjustsStack(true); // Is this only for X86 target? + MFI.setAdjustsStack(true); // Is this only for X86 target? MFI.setHasCalls(true); assert((GA->getOffset() == 0) && @@ -5282,15 +5841,19 @@ SDValue TargetLowering::expandAddSubSat(SDNode *Node, SelectionDAG &DAG) const { EVT VT = LHS.getValueType(); SDLoc dl(Node); + assert(VT == RHS.getValueType() && "Expected operands to be the same type"); + assert(VT.isInteger() && "Expected operands to be integers"); + // usub.sat(a, b) -> umax(a, b) - b if (Opcode == ISD::USUBSAT && isOperationLegalOrCustom(ISD::UMAX, VT)) { SDValue Max = DAG.getNode(ISD::UMAX, dl, VT, LHS, RHS); return DAG.getNode(ISD::SUB, dl, VT, Max, RHS); } - if (VT.isVector()) { - // TODO: Consider not scalarizing here. - return SDValue(); + if (Opcode == ISD::UADDSAT && isOperationLegalOrCustom(ISD::UMIN, VT)) { + SDValue InvRHS = DAG.getNOT(dl, RHS, VT); + SDValue Min = DAG.getNode(ISD::UMIN, dl, VT, LHS, InvRHS); + return DAG.getNode(ISD::ADD, dl, VT, Min, RHS); } unsigned OverflowOp; @@ -5312,96 +5875,410 @@ SDValue TargetLowering::expandAddSubSat(SDNode *Node, SelectionDAG &DAG) const { "addition or subtraction node."); } - assert(LHS.getValueType().isScalarInteger() && - "Expected operands to be integers. Vector of int arguments should " - "already be unrolled."); - assert(RHS.getValueType().isScalarInteger() && - "Expected operands to be integers. Vector of int arguments should " - "already be unrolled."); - assert(LHS.getValueType() == RHS.getValueType() && - "Expected both operands to be the same type"); - - unsigned BitWidth = LHS.getValueSizeInBits(); - EVT ResultType = LHS.getValueType(); - EVT BoolVT = - getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), ResultType); - SDValue Result = - DAG.getNode(OverflowOp, dl, DAG.getVTList(ResultType, BoolVT), LHS, RHS); + unsigned BitWidth = LHS.getScalarValueSizeInBits(); + EVT BoolVT = getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), VT); + SDValue Result = DAG.getNode(OverflowOp, dl, DAG.getVTList(VT, BoolVT), + LHS, RHS); SDValue SumDiff = Result.getValue(0); SDValue Overflow = Result.getValue(1); - SDValue Zero = DAG.getConstant(0, dl, ResultType); + SDValue Zero = DAG.getConstant(0, dl, VT); + SDValue AllOnes = DAG.getAllOnesConstant(dl, VT); if (Opcode == ISD::UADDSAT) { - // Just need to check overflow for SatMax. - APInt MaxVal = APInt::getMaxValue(BitWidth); - SDValue SatMax = DAG.getConstant(MaxVal, dl, ResultType); - return DAG.getSelect(dl, ResultType, Overflow, SatMax, SumDiff); + if (getBooleanContents(VT) == ZeroOrNegativeOneBooleanContent) { + // (LHS + RHS) | OverflowMask + SDValue OverflowMask = DAG.getSExtOrTrunc(Overflow, dl, VT); + return DAG.getNode(ISD::OR, dl, VT, SumDiff, OverflowMask); + } + // Overflow ? 0xffff.... : (LHS + RHS) + return DAG.getSelect(dl, VT, Overflow, AllOnes, SumDiff); } else if (Opcode == ISD::USUBSAT) { - // Just need to check overflow for SatMin. - APInt MinVal = APInt::getMinValue(BitWidth); - SDValue SatMin = DAG.getConstant(MinVal, dl, ResultType); - return DAG.getSelect(dl, ResultType, Overflow, SatMin, SumDiff); + if (getBooleanContents(VT) == ZeroOrNegativeOneBooleanContent) { + // (LHS - RHS) & ~OverflowMask + SDValue OverflowMask = DAG.getSExtOrTrunc(Overflow, dl, VT); + SDValue Not = DAG.getNOT(dl, OverflowMask, VT); + return DAG.getNode(ISD::AND, dl, VT, SumDiff, Not); + } + // Overflow ? 0 : (LHS - RHS) + return DAG.getSelect(dl, VT, Overflow, Zero, SumDiff); } else { // SatMax -> Overflow && SumDiff < 0 // SatMin -> Overflow && SumDiff >= 0 APInt MinVal = APInt::getSignedMinValue(BitWidth); APInt MaxVal = APInt::getSignedMaxValue(BitWidth); - SDValue SatMin = DAG.getConstant(MinVal, dl, ResultType); - SDValue SatMax = DAG.getConstant(MaxVal, dl, ResultType); + SDValue SatMin = DAG.getConstant(MinVal, dl, VT); + SDValue SatMax = DAG.getConstant(MaxVal, dl, VT); SDValue SumNeg = DAG.getSetCC(dl, BoolVT, SumDiff, Zero, ISD::SETLT); - Result = DAG.getSelect(dl, ResultType, SumNeg, SatMax, SatMin); - return DAG.getSelect(dl, ResultType, Overflow, Result, SumDiff); + Result = DAG.getSelect(dl, VT, SumNeg, SatMax, SatMin); + return DAG.getSelect(dl, VT, Overflow, Result, SumDiff); } } SDValue -TargetLowering::getExpandedFixedPointMultiplication(SDNode *Node, - SelectionDAG &DAG) const { - assert(Node->getOpcode() == ISD::SMULFIX && "Expected opcode to be SMULFIX."); - assert(Node->getNumOperands() == 3 && - "Expected signed fixed point multiplication to have 3 operands."); +TargetLowering::expandFixedPointMul(SDNode *Node, SelectionDAG &DAG) const { + assert((Node->getOpcode() == ISD::SMULFIX || + Node->getOpcode() == ISD::UMULFIX || + Node->getOpcode() == ISD::SMULFIXSAT) && + "Expected a fixed point multiplication opcode"); SDLoc dl(Node); SDValue LHS = Node->getOperand(0); SDValue RHS = Node->getOperand(1); - assert(LHS.getValueType().isScalarInteger() && - "Expected operands to be integers. Vector of int arguments should " - "already be unrolled."); - assert(RHS.getValueType().isScalarInteger() && - "Expected operands to be integers. Vector of int arguments should " - "already be unrolled."); - assert(LHS.getValueType() == RHS.getValueType() && - "Expected both operands to be the same type"); - - unsigned Scale = Node->getConstantOperandVal(2); EVT VT = LHS.getValueType(); - assert(Scale < VT.getScalarSizeInBits() && - "Expected scale to be less than the number of bits."); + unsigned Scale = Node->getConstantOperandVal(2); + bool Saturating = Node->getOpcode() == ISD::SMULFIXSAT; + EVT BoolVT = getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), VT); + unsigned VTSize = VT.getScalarSizeInBits(); + + if (!Scale) { + // [us]mul.fix(a, b, 0) -> mul(a, b) + if (!Saturating && isOperationLegalOrCustom(ISD::MUL, VT)) { + return DAG.getNode(ISD::MUL, dl, VT, LHS, RHS); + } else if (Saturating && isOperationLegalOrCustom(ISD::SMULO, VT)) { + SDValue Result = + DAG.getNode(ISD::SMULO, dl, DAG.getVTList(VT, BoolVT), LHS, RHS); + SDValue Product = Result.getValue(0); + SDValue Overflow = Result.getValue(1); + SDValue Zero = DAG.getConstant(0, dl, VT); + + APInt MinVal = APInt::getSignedMinValue(VTSize); + APInt MaxVal = APInt::getSignedMaxValue(VTSize); + SDValue SatMin = DAG.getConstant(MinVal, dl, VT); + SDValue SatMax = DAG.getConstant(MaxVal, dl, VT); + SDValue ProdNeg = DAG.getSetCC(dl, BoolVT, Product, Zero, ISD::SETLT); + Result = DAG.getSelect(dl, VT, ProdNeg, SatMax, SatMin); + return DAG.getSelect(dl, VT, Overflow, Result, Product); + } + } - if (!Scale) - return DAG.getNode(ISD::MUL, dl, VT, LHS, RHS); + bool Signed = + Node->getOpcode() == ISD::SMULFIX || Node->getOpcode() == ISD::SMULFIXSAT; + assert(((Signed && Scale < VTSize) || (!Signed && Scale <= VTSize)) && + "Expected scale to be less than the number of bits if signed or at " + "most the number of bits if unsigned."); + assert(LHS.getValueType() == RHS.getValueType() && + "Expected both operands to be the same type"); // Get the upper and lower bits of the result. SDValue Lo, Hi; - if (isOperationLegalOrCustom(ISD::SMUL_LOHI, VT)) { - SDValue Result = - DAG.getNode(ISD::SMUL_LOHI, dl, DAG.getVTList(VT, VT), LHS, RHS); + unsigned LoHiOp = Signed ? ISD::SMUL_LOHI : ISD::UMUL_LOHI; + unsigned HiOp = Signed ? ISD::MULHS : ISD::MULHU; + if (isOperationLegalOrCustom(LoHiOp, VT)) { + SDValue Result = DAG.getNode(LoHiOp, dl, DAG.getVTList(VT, VT), LHS, RHS); Lo = Result.getValue(0); Hi = Result.getValue(1); - } else if (isOperationLegalOrCustom(ISD::MULHS, VT)) { + } else if (isOperationLegalOrCustom(HiOp, VT)) { Lo = DAG.getNode(ISD::MUL, dl, VT, LHS, RHS); - Hi = DAG.getNode(ISD::MULHS, dl, VT, LHS, RHS); + Hi = DAG.getNode(HiOp, dl, VT, LHS, RHS); + } else if (VT.isVector()) { + return SDValue(); } else { - report_fatal_error("Unable to expand signed fixed point multiplication."); + report_fatal_error("Unable to expand fixed point multiplication."); } + if (Scale == VTSize) + // Result is just the top half since we'd be shifting by the width of the + // operand. + return Hi; + // The result will need to be shifted right by the scale since both operands // are scaled. The result is given to us in 2 halves, so we only want part of // both in the result. EVT ShiftTy = getShiftAmountTy(VT, DAG.getDataLayout()); - Lo = DAG.getNode(ISD::SRL, dl, VT, Lo, DAG.getConstant(Scale, dl, ShiftTy)); - Hi = DAG.getNode( - ISD::SHL, dl, VT, Hi, - DAG.getConstant(VT.getScalarSizeInBits() - Scale, dl, ShiftTy)); - return DAG.getNode(ISD::OR, dl, VT, Lo, Hi); + SDValue Result = DAG.getNode(ISD::FSHR, dl, VT, Hi, Lo, + DAG.getConstant(Scale, dl, ShiftTy)); + if (!Saturating) + return Result; + + unsigned OverflowBits = VTSize - Scale + 1; // +1 for the sign + SDValue HiMask = + DAG.getConstant(APInt::getHighBitsSet(VTSize, OverflowBits), dl, VT); + SDValue LoMask = DAG.getConstant( + APInt::getLowBitsSet(VTSize, VTSize - OverflowBits), dl, VT); + APInt MaxVal = APInt::getSignedMaxValue(VTSize); + APInt MinVal = APInt::getSignedMinValue(VTSize); + + Result = DAG.getSelectCC(dl, Hi, LoMask, + DAG.getConstant(MaxVal, dl, VT), Result, + ISD::SETGT); + return DAG.getSelectCC(dl, Hi, HiMask, + DAG.getConstant(MinVal, dl, VT), Result, + ISD::SETLT); +} + +void TargetLowering::expandUADDSUBO( + SDNode *Node, SDValue &Result, SDValue &Overflow, SelectionDAG &DAG) const { + SDLoc dl(Node); + SDValue LHS = Node->getOperand(0); + SDValue RHS = Node->getOperand(1); + bool IsAdd = Node->getOpcode() == ISD::UADDO; + + // If ADD/SUBCARRY is legal, use that instead. + unsigned OpcCarry = IsAdd ? ISD::ADDCARRY : ISD::SUBCARRY; + if (isOperationLegalOrCustom(OpcCarry, Node->getValueType(0))) { + SDValue CarryIn = DAG.getConstant(0, dl, Node->getValueType(1)); + SDValue NodeCarry = DAG.getNode(OpcCarry, dl, Node->getVTList(), + { LHS, RHS, CarryIn }); + Result = SDValue(NodeCarry.getNode(), 0); + Overflow = SDValue(NodeCarry.getNode(), 1); + return; + } + + Result = DAG.getNode(IsAdd ? ISD::ADD : ISD::SUB, dl, + LHS.getValueType(), LHS, RHS); + + EVT ResultType = Node->getValueType(1); + EVT SetCCType = getSetCCResultType( + DAG.getDataLayout(), *DAG.getContext(), Node->getValueType(0)); + ISD::CondCode CC = IsAdd ? ISD::SETULT : ISD::SETUGT; + SDValue SetCC = DAG.getSetCC(dl, SetCCType, Result, LHS, CC); + Overflow = DAG.getBoolExtOrTrunc(SetCC, dl, ResultType, ResultType); +} + +void TargetLowering::expandSADDSUBO( + SDNode *Node, SDValue &Result, SDValue &Overflow, SelectionDAG &DAG) const { + SDLoc dl(Node); + SDValue LHS = Node->getOperand(0); + SDValue RHS = Node->getOperand(1); + bool IsAdd = Node->getOpcode() == ISD::SADDO; + + Result = DAG.getNode(IsAdd ? ISD::ADD : ISD::SUB, dl, + LHS.getValueType(), LHS, RHS); + + EVT ResultType = Node->getValueType(1); + EVT OType = getSetCCResultType( + DAG.getDataLayout(), *DAG.getContext(), Node->getValueType(0)); + + // If SADDSAT/SSUBSAT is legal, compare results to detect overflow. + unsigned OpcSat = IsAdd ? ISD::SADDSAT : ISD::SSUBSAT; + if (isOperationLegalOrCustom(OpcSat, LHS.getValueType())) { + SDValue Sat = DAG.getNode(OpcSat, dl, LHS.getValueType(), LHS, RHS); + SDValue SetCC = DAG.getSetCC(dl, OType, Result, Sat, ISD::SETNE); + Overflow = DAG.getBoolExtOrTrunc(SetCC, dl, ResultType, ResultType); + return; + } + + SDValue Zero = DAG.getConstant(0, dl, LHS.getValueType()); + + // LHSSign -> LHS >= 0 + // RHSSign -> RHS >= 0 + // SumSign -> Result >= 0 + // + // Add: + // Overflow -> (LHSSign == RHSSign) && (LHSSign != SumSign) + // Sub: + // Overflow -> (LHSSign != RHSSign) && (LHSSign != SumSign) + SDValue LHSSign = DAG.getSetCC(dl, OType, LHS, Zero, ISD::SETGE); + SDValue RHSSign = DAG.getSetCC(dl, OType, RHS, Zero, ISD::SETGE); + SDValue SignsMatch = DAG.getSetCC(dl, OType, LHSSign, RHSSign, + IsAdd ? ISD::SETEQ : ISD::SETNE); + + SDValue SumSign = DAG.getSetCC(dl, OType, Result, Zero, ISD::SETGE); + SDValue SumSignNE = DAG.getSetCC(dl, OType, LHSSign, SumSign, ISD::SETNE); + + SDValue Cmp = DAG.getNode(ISD::AND, dl, OType, SignsMatch, SumSignNE); + Overflow = DAG.getBoolExtOrTrunc(Cmp, dl, ResultType, ResultType); +} + +bool TargetLowering::expandMULO(SDNode *Node, SDValue &Result, + SDValue &Overflow, SelectionDAG &DAG) const { + SDLoc dl(Node); + EVT VT = Node->getValueType(0); + EVT SetCCVT = getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), VT); + SDValue LHS = Node->getOperand(0); + SDValue RHS = Node->getOperand(1); + bool isSigned = Node->getOpcode() == ISD::SMULO; + + // For power-of-two multiplications we can use a simpler shift expansion. + if (ConstantSDNode *RHSC = isConstOrConstSplat(RHS)) { + const APInt &C = RHSC->getAPIntValue(); + // mulo(X, 1 << S) -> { X << S, (X << S) >> S != X } + if (C.isPowerOf2()) { + // smulo(x, signed_min) is same as umulo(x, signed_min). + bool UseArithShift = isSigned && !C.isMinSignedValue(); + EVT ShiftAmtTy = getShiftAmountTy(VT, DAG.getDataLayout()); + SDValue ShiftAmt = DAG.getConstant(C.logBase2(), dl, ShiftAmtTy); + Result = DAG.getNode(ISD::SHL, dl, VT, LHS, ShiftAmt); + Overflow = DAG.getSetCC(dl, SetCCVT, + DAG.getNode(UseArithShift ? ISD::SRA : ISD::SRL, + dl, VT, Result, ShiftAmt), + LHS, ISD::SETNE); + return true; + } + } + + EVT WideVT = EVT::getIntegerVT(*DAG.getContext(), VT.getScalarSizeInBits() * 2); + if (VT.isVector()) + WideVT = EVT::getVectorVT(*DAG.getContext(), WideVT, + VT.getVectorNumElements()); + + SDValue BottomHalf; + SDValue TopHalf; + static const unsigned Ops[2][3] = + { { ISD::MULHU, ISD::UMUL_LOHI, ISD::ZERO_EXTEND }, + { ISD::MULHS, ISD::SMUL_LOHI, ISD::SIGN_EXTEND }}; + if (isOperationLegalOrCustom(Ops[isSigned][0], VT)) { + BottomHalf = DAG.getNode(ISD::MUL, dl, VT, LHS, RHS); + TopHalf = DAG.getNode(Ops[isSigned][0], dl, VT, LHS, RHS); + } else if (isOperationLegalOrCustom(Ops[isSigned][1], VT)) { + BottomHalf = DAG.getNode(Ops[isSigned][1], dl, DAG.getVTList(VT, VT), LHS, + RHS); + TopHalf = BottomHalf.getValue(1); + } else if (isTypeLegal(WideVT)) { + LHS = DAG.getNode(Ops[isSigned][2], dl, WideVT, LHS); + RHS = DAG.getNode(Ops[isSigned][2], dl, WideVT, RHS); + SDValue Mul = DAG.getNode(ISD::MUL, dl, WideVT, LHS, RHS); + BottomHalf = DAG.getNode(ISD::TRUNCATE, dl, VT, Mul); + SDValue ShiftAmt = DAG.getConstant(VT.getScalarSizeInBits(), dl, + getShiftAmountTy(WideVT, DAG.getDataLayout())); + TopHalf = DAG.getNode(ISD::TRUNCATE, dl, VT, + DAG.getNode(ISD::SRL, dl, WideVT, Mul, ShiftAmt)); + } else { + if (VT.isVector()) + return false; + + // We can fall back to a libcall with an illegal type for the MUL if we + // have a libcall big enough. + // Also, we can fall back to a division in some cases, but that's a big + // performance hit in the general case. + RTLIB::Libcall LC = RTLIB::UNKNOWN_LIBCALL; + if (WideVT == MVT::i16) + LC = RTLIB::MUL_I16; + else if (WideVT == MVT::i32) + LC = RTLIB::MUL_I32; + else if (WideVT == MVT::i64) + LC = RTLIB::MUL_I64; + else if (WideVT == MVT::i128) + LC = RTLIB::MUL_I128; + assert(LC != RTLIB::UNKNOWN_LIBCALL && "Cannot expand this operation!"); + + SDValue HiLHS; + SDValue HiRHS; + if (isSigned) { + // The high part is obtained by SRA'ing all but one of the bits of low + // part. + unsigned LoSize = VT.getSizeInBits(); + HiLHS = + DAG.getNode(ISD::SRA, dl, VT, LHS, + DAG.getConstant(LoSize - 1, dl, + getPointerTy(DAG.getDataLayout()))); + HiRHS = + DAG.getNode(ISD::SRA, dl, VT, RHS, + DAG.getConstant(LoSize - 1, dl, + getPointerTy(DAG.getDataLayout()))); + } else { + HiLHS = DAG.getConstant(0, dl, VT); + HiRHS = DAG.getConstant(0, dl, VT); + } + + // Here we're passing the 2 arguments explicitly as 4 arguments that are + // pre-lowered to the correct types. This all depends upon WideVT not + // being a legal type for the architecture and thus has to be split to + // two arguments. + SDValue Ret; + if (shouldSplitFunctionArgumentsAsLittleEndian(DAG.getDataLayout())) { + // Halves of WideVT are packed into registers in different order + // depending on platform endianness. This is usually handled by + // the C calling convention, but we can't defer to it in + // the legalizer. + SDValue Args[] = { LHS, HiLHS, RHS, HiRHS }; + Ret = makeLibCall(DAG, LC, WideVT, Args, isSigned, dl, + /* doesNotReturn */ false, /* isReturnValueUsed */ true, + /* isPostTypeLegalization */ true).first; + } else { + SDValue Args[] = { HiLHS, LHS, HiRHS, RHS }; + Ret = makeLibCall(DAG, LC, WideVT, Args, isSigned, dl, + /* doesNotReturn */ false, /* isReturnValueUsed */ true, + /* isPostTypeLegalization */ true).first; + } + assert(Ret.getOpcode() == ISD::MERGE_VALUES && + "Ret value is a collection of constituent nodes holding result."); + if (DAG.getDataLayout().isLittleEndian()) { + // Same as above. + BottomHalf = Ret.getOperand(0); + TopHalf = Ret.getOperand(1); + } else { + BottomHalf = Ret.getOperand(1); + TopHalf = Ret.getOperand(0); + } + } + + Result = BottomHalf; + if (isSigned) { + SDValue ShiftAmt = DAG.getConstant( + VT.getScalarSizeInBits() - 1, dl, + getShiftAmountTy(BottomHalf.getValueType(), DAG.getDataLayout())); + SDValue Sign = DAG.getNode(ISD::SRA, dl, VT, BottomHalf, ShiftAmt); + Overflow = DAG.getSetCC(dl, SetCCVT, TopHalf, Sign, ISD::SETNE); + } else { + Overflow = DAG.getSetCC(dl, SetCCVT, TopHalf, + DAG.getConstant(0, dl, VT), ISD::SETNE); + } + + // Truncate the result if SetCC returns a larger type than needed. + EVT RType = Node->getValueType(1); + if (RType.getSizeInBits() < Overflow.getValueSizeInBits()) + Overflow = DAG.getNode(ISD::TRUNCATE, dl, RType, Overflow); + + assert(RType.getSizeInBits() == Overflow.getValueSizeInBits() && + "Unexpected result type for S/UMULO legalization"); + return true; +} + +SDValue TargetLowering::expandVecReduce(SDNode *Node, SelectionDAG &DAG) const { + SDLoc dl(Node); + bool NoNaN = Node->getFlags().hasNoNaNs(); + unsigned BaseOpcode = 0; + switch (Node->getOpcode()) { + default: llvm_unreachable("Expected VECREDUCE opcode"); + case ISD::VECREDUCE_FADD: BaseOpcode = ISD::FADD; break; + case ISD::VECREDUCE_FMUL: BaseOpcode = ISD::FMUL; break; + case ISD::VECREDUCE_ADD: BaseOpcode = ISD::ADD; break; + case ISD::VECREDUCE_MUL: BaseOpcode = ISD::MUL; break; + case ISD::VECREDUCE_AND: BaseOpcode = ISD::AND; break; + case ISD::VECREDUCE_OR: BaseOpcode = ISD::OR; break; + case ISD::VECREDUCE_XOR: BaseOpcode = ISD::XOR; break; + case ISD::VECREDUCE_SMAX: BaseOpcode = ISD::SMAX; break; + case ISD::VECREDUCE_SMIN: BaseOpcode = ISD::SMIN; break; + case ISD::VECREDUCE_UMAX: BaseOpcode = ISD::UMAX; break; + case ISD::VECREDUCE_UMIN: BaseOpcode = ISD::UMIN; break; + case ISD::VECREDUCE_FMAX: + BaseOpcode = NoNaN ? ISD::FMAXNUM : ISD::FMAXIMUM; + break; + case ISD::VECREDUCE_FMIN: + BaseOpcode = NoNaN ? ISD::FMINNUM : ISD::FMINIMUM; + break; + } + + SDValue Op = Node->getOperand(0); + EVT VT = Op.getValueType(); + + // Try to use a shuffle reduction for power of two vectors. + if (VT.isPow2VectorType()) { + while (VT.getVectorNumElements() > 1) { + EVT HalfVT = VT.getHalfNumVectorElementsVT(*DAG.getContext()); + if (!isOperationLegalOrCustom(BaseOpcode, HalfVT)) + break; + + SDValue Lo, Hi; + std::tie(Lo, Hi) = DAG.SplitVector(Op, dl); + Op = DAG.getNode(BaseOpcode, dl, HalfVT, Lo, Hi); + VT = HalfVT; + } + } + + EVT EltVT = VT.getVectorElementType(); + unsigned NumElts = VT.getVectorNumElements(); + + SmallVector<SDValue, 8> Ops; + DAG.ExtractVectorElements(Op, Ops, 0, NumElts); + + SDValue Res = Ops[0]; + for (unsigned i = 1; i < NumElts; i++) + Res = DAG.getNode(BaseOpcode, dl, EltVT, Res, Ops[i], Node->getFlags()); + + // Result type may be wider than element type. + if (EltVT != Node->getValueType(0)) + Res = DAG.getNode(ISD::ANY_EXTEND, dl, Node->getValueType(0), Res); + return Res; } |