diff options
Diffstat (limited to 'contrib/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp')
| -rw-r--r-- | contrib/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp | 1926 |
1 files changed, 1466 insertions, 460 deletions
diff --git a/contrib/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/contrib/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp index e317268fa5f4..a2f05c1e3cef 100644 --- a/contrib/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp +++ b/contrib/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp @@ -55,10 +55,12 @@ bool TargetLowering::isInTailCallPosition(SelectionDAG &DAG, SDNode *Node, const Function &F = DAG.getMachineFunction().getFunction(); // Conservatively require the attributes of the call to match those of - // the return. Ignore noalias because it doesn't affect the call sequence. + // the return. Ignore NoAlias and NonNull because they don't affect the + // call sequence. AttributeList CallerAttrs = F.getAttributes(); if (AttrBuilder(CallerAttrs, AttributeList::ReturnIndex) .removeAttribute(Attribute::NoAlias) + .removeAttribute(Attribute::NonNull) .hasAttributes()) return false; @@ -429,87 +431,56 @@ bool TargetLowering::ShrinkDemandedOp(SDValue Op, unsigned BitWidth, return false; } -bool -TargetLowering::SimplifyDemandedBits(SDNode *User, unsigned OpIdx, - const APInt &Demanded, - DAGCombinerInfo &DCI, - TargetLoweringOpt &TLO) const { - SDValue Op = User->getOperand(OpIdx); - KnownBits Known; - - if (!SimplifyDemandedBits(Op, Demanded, Known, TLO, 0, true)) - return false; - - - // Old will not always be the same as Op. For example: - // - // Demanded = 0xffffff - // Op = i64 truncate (i32 and x, 0xffffff) - // In this case simplify demand bits will want to replace the 'and' node - // with the value 'x', which will give us: - // Old = i32 and x, 0xffffff - // New = x - if (TLO.Old.hasOneUse()) { - // For the one use case, we just commit the change. - DCI.CommitTargetLoweringOpt(TLO); - return true; - } - - // If Old has more than one use then it must be Op, because the - // AssumeSingleUse flag is not propogated to recursive calls of - // SimplifyDemanded bits, so the only node with multiple use that - // it will attempt to combine will be Op. - assert(TLO.Old == Op); - - SmallVector <SDValue, 4> NewOps; - for (unsigned i = 0, e = User->getNumOperands(); i != e; ++i) { - if (i == OpIdx) { - NewOps.push_back(TLO.New); - continue; - } - NewOps.push_back(User->getOperand(i)); - } - User = TLO.DAG.UpdateNodeOperands(User, NewOps); - // Op has less users now, so we may be able to perform additional combines - // with it. - DCI.AddToWorklist(Op.getNode()); - // User's operands have been updated, so we may be able to do new combines - // with it. - DCI.AddToWorklist(User); - return true; -} - -bool TargetLowering::SimplifyDemandedBits(SDValue Op, const APInt &DemandedMask, +bool TargetLowering::SimplifyDemandedBits(SDValue Op, const APInt &DemandedBits, DAGCombinerInfo &DCI) const { - SelectionDAG &DAG = DCI.DAG; TargetLoweringOpt TLO(DAG, !DCI.isBeforeLegalize(), !DCI.isBeforeLegalizeOps()); KnownBits Known; - bool Simplified = SimplifyDemandedBits(Op, DemandedMask, Known, TLO); - if (Simplified) + bool Simplified = SimplifyDemandedBits(Op, DemandedBits, Known, TLO); + if (Simplified) { + DCI.AddToWorklist(Op.getNode()); DCI.CommitTargetLoweringOpt(TLO); + } return Simplified; } -/// Look at Op. At this point, we know that only the DemandedMask bits of the +bool TargetLowering::SimplifyDemandedBits(SDValue Op, const APInt &DemandedBits, + KnownBits &Known, + TargetLoweringOpt &TLO, + unsigned Depth, + bool AssumeSingleUse) const { + EVT VT = Op.getValueType(); + APInt DemandedElts = VT.isVector() + ? APInt::getAllOnesValue(VT.getVectorNumElements()) + : APInt(1, 1); + return SimplifyDemandedBits(Op, DemandedBits, DemandedElts, Known, TLO, Depth, + AssumeSingleUse); +} + +/// Look at Op. At this point, we know that only the OriginalDemandedBits of the /// result of Op are ever used downstream. If we can use this information to /// simplify Op, create a new simplified DAG node and return true, returning the /// original and new nodes in Old and New. Otherwise, analyze the expression and /// return a mask of Known bits for the expression (used to simplify the /// caller). The Known bits may only be accurate for those bits in the -/// DemandedMask. -bool TargetLowering::SimplifyDemandedBits(SDValue Op, - const APInt &DemandedMask, - KnownBits &Known, - TargetLoweringOpt &TLO, - unsigned Depth, - bool AssumeSingleUse) const { - unsigned BitWidth = DemandedMask.getBitWidth(); +/// OriginalDemandedBits and OriginalDemandedElts. +bool TargetLowering::SimplifyDemandedBits( + SDValue Op, const APInt &OriginalDemandedBits, + const APInt &OriginalDemandedElts, KnownBits &Known, TargetLoweringOpt &TLO, + unsigned Depth, bool AssumeSingleUse) const { + unsigned BitWidth = OriginalDemandedBits.getBitWidth(); assert(Op.getScalarValueSizeInBits() == BitWidth && "Mask size mismatches value type size!"); - APInt NewMask = DemandedMask; + + unsigned NumElts = OriginalDemandedElts.getBitWidth(); + assert((!Op.getValueType().isVector() || + NumElts == Op.getValueType().getVectorNumElements()) && + "Unexpected vector size"); + + APInt DemandedBits = OriginalDemandedBits; + APInt DemandedElts = OriginalDemandedElts; SDLoc dl(Op); auto &DL = TLO.DAG.getDataLayout(); @@ -529,18 +500,19 @@ bool TargetLowering::SimplifyDemandedBits(SDValue Op, if (Depth != 0) { // If not at the root, Just compute the Known bits to // simplify things downstream. - TLO.DAG.computeKnownBits(Op, Known, Depth); + Known = TLO.DAG.computeKnownBits(Op, DemandedElts, Depth); return false; } // If this is the root being simplified, allow it to have multiple uses, - // just set the NewMask to all bits. - NewMask = APInt::getAllOnesValue(BitWidth); - } else if (DemandedMask == 0) { - // Not demanding any bits from Op. + // just set the DemandedBits/Elts to all bits. + DemandedBits = APInt::getAllOnesValue(BitWidth); + DemandedElts = APInt::getAllOnesValue(NumElts); + } else if (OriginalDemandedBits == 0 || OriginalDemandedElts == 0) { + // Not demanding any bits/elts from Op. if (!Op.isUndef()) return TLO.CombineTo(Op, TLO.DAG.getUNDEF(VT)); return false; - } else if (Depth == 6) { // Limit search depth. + } else if (Depth == 6) { // Limit search depth. return false; } @@ -570,24 +542,90 @@ bool TargetLowering::SimplifyDemandedBits(SDValue Op, Known.One &= Known2.One; Known.Zero &= Known2.Zero; } - return false; // Don't fall through, will infinitely loop. - case ISD::AND: + return false; // Don't fall through, will infinitely loop. + case ISD::CONCAT_VECTORS: { + Known.Zero.setAllBits(); + Known.One.setAllBits(); + EVT SubVT = Op.getOperand(0).getValueType(); + unsigned NumSubVecs = Op.getNumOperands(); + unsigned NumSubElts = SubVT.getVectorNumElements(); + for (unsigned i = 0; i != NumSubVecs; ++i) { + APInt DemandedSubElts = + DemandedElts.extractBits(NumSubElts, i * NumSubElts); + if (SimplifyDemandedBits(Op.getOperand(i), DemandedBits, DemandedSubElts, + Known2, TLO, Depth + 1)) + return true; + // Known bits are shared by every demanded subvector element. + if (!!DemandedSubElts) { + Known.One &= Known2.One; + Known.Zero &= Known2.Zero; + } + } + break; + } + case ISD::VECTOR_SHUFFLE: { + ArrayRef<int> ShuffleMask = cast<ShuffleVectorSDNode>(Op)->getMask(); + + // Collect demanded elements from shuffle operands.. + APInt DemandedLHS(NumElts, 0); + APInt DemandedRHS(NumElts, 0); + for (unsigned i = 0; i != NumElts; ++i) { + if (!DemandedElts[i]) + continue; + int M = ShuffleMask[i]; + if (M < 0) { + // For UNDEF elements, we don't know anything about the common state of + // the shuffle result. + DemandedLHS.clearAllBits(); + DemandedRHS.clearAllBits(); + break; + } + assert(0 <= M && M < (int)(2 * NumElts) && "Shuffle index out of range"); + if (M < (int)NumElts) + DemandedLHS.setBit(M); + else + DemandedRHS.setBit(M - NumElts); + } + + if (!!DemandedLHS || !!DemandedRHS) { + Known.Zero.setAllBits(); + Known.One.setAllBits(); + if (!!DemandedLHS) { + if (SimplifyDemandedBits(Op.getOperand(0), DemandedBits, DemandedLHS, + Known2, TLO, Depth + 1)) + return true; + Known.One &= Known2.One; + Known.Zero &= Known2.Zero; + } + if (!!DemandedRHS) { + if (SimplifyDemandedBits(Op.getOperand(1), DemandedBits, DemandedRHS, + Known2, TLO, Depth + 1)) + return true; + Known.One &= Known2.One; + Known.Zero &= Known2.Zero; + } + } + break; + } + case ISD::AND: { + SDValue Op0 = Op.getOperand(0); + SDValue Op1 = Op.getOperand(1); + // If the RHS is a constant, check to see if the LHS would be zero without // using the bits from the RHS. Below, we use knowledge about the RHS to // simplify the LHS, here we're using information from the LHS to simplify // the RHS. - if (ConstantSDNode *RHSC = isConstOrConstSplat(Op.getOperand(1))) { - SDValue Op0 = Op.getOperand(0); - KnownBits LHSKnown; + if (ConstantSDNode *RHSC = isConstOrConstSplat(Op1)) { // Do not increment Depth here; that can cause an infinite loop. - TLO.DAG.computeKnownBits(Op0, LHSKnown, Depth); + KnownBits LHSKnown = TLO.DAG.computeKnownBits(Op0, DemandedElts, Depth); // If the LHS already has zeros where RHSC does, this 'and' is dead. - if ((LHSKnown.Zero & NewMask) == (~RHSC->getAPIntValue() & NewMask)) + if ((LHSKnown.Zero & DemandedBits) == + (~RHSC->getAPIntValue() & DemandedBits)) return TLO.CombineTo(Op, Op0); // If any of the set bits in the RHS are known zero on the LHS, shrink // the constant. - if (ShrinkDemandedConstant(Op, ~LHSKnown.Zero & NewMask, TLO)) + if (ShrinkDemandedConstant(Op, ~LHSKnown.Zero & DemandedBits, TLO)) return true; // Bitwise-not (xor X, -1) is a special case: we don't usually shrink its @@ -597,34 +635,33 @@ bool TargetLowering::SimplifyDemandedBits(SDValue Op, // and (xor (srl X, 31), -1), 1 --> xor (srl X, 31), 1 if (isBitwiseNot(Op0) && Op0.hasOneUse() && LHSKnown.One == ~RHSC->getAPIntValue()) { - SDValue Xor = TLO.DAG.getNode(ISD::XOR, dl, VT, Op0.getOperand(0), - Op.getOperand(1)); + SDValue Xor = TLO.DAG.getNode(ISD::XOR, dl, VT, Op0.getOperand(0), Op1); return TLO.CombineTo(Op, Xor); } } - if (SimplifyDemandedBits(Op.getOperand(1), NewMask, Known, TLO, Depth+1)) + if (SimplifyDemandedBits(Op1, DemandedBits, DemandedElts, Known, TLO, Depth + 1)) return true; assert(!Known.hasConflict() && "Bits known to be one AND zero?"); - if (SimplifyDemandedBits(Op.getOperand(0), ~Known.Zero & NewMask, - Known2, TLO, Depth+1)) + if (SimplifyDemandedBits(Op0, ~Known.Zero & DemandedBits, DemandedElts, Known2, TLO, + Depth + 1)) return true; assert(!Known2.hasConflict() && "Bits known to be one AND zero?"); // If all of the demanded bits are known one on one side, return the other. // These bits cannot contribute to the result of the 'and'. - if (NewMask.isSubsetOf(Known2.Zero | Known.One)) - return TLO.CombineTo(Op, Op.getOperand(0)); - if (NewMask.isSubsetOf(Known.Zero | Known2.One)) - return TLO.CombineTo(Op, Op.getOperand(1)); + if (DemandedBits.isSubsetOf(Known2.Zero | Known.One)) + return TLO.CombineTo(Op, Op0); + if (DemandedBits.isSubsetOf(Known.Zero | Known2.One)) + return TLO.CombineTo(Op, Op1); // If all of the demanded bits in the inputs are known zeros, return zero. - if (NewMask.isSubsetOf(Known.Zero | Known2.Zero)) + if (DemandedBits.isSubsetOf(Known.Zero | Known2.Zero)) return TLO.CombineTo(Op, TLO.DAG.getConstant(0, dl, VT)); // If the RHS is a constant, see if we can simplify it. - if (ShrinkDemandedConstant(Op, ~Known2.Zero & NewMask, TLO)) + if (ShrinkDemandedConstant(Op, ~Known2.Zero & DemandedBits, TLO)) return true; // If the operation can be done in a smaller type, do so. - if (ShrinkDemandedOp(Op, BitWidth, NewMask, TLO)) + if (ShrinkDemandedOp(Op, BitWidth, DemandedBits, TLO)) return true; // Output known-1 bits are only known if set in both the LHS & RHS. @@ -632,26 +669,30 @@ bool TargetLowering::SimplifyDemandedBits(SDValue Op, // Output known-0 are known to be clear if zero in either the LHS | RHS. Known.Zero |= Known2.Zero; break; - case ISD::OR: - if (SimplifyDemandedBits(Op.getOperand(1), NewMask, Known, TLO, Depth+1)) + } + case ISD::OR: { + SDValue Op0 = Op.getOperand(0); + SDValue Op1 = Op.getOperand(1); + + if (SimplifyDemandedBits(Op1, DemandedBits, DemandedElts, Known, TLO, Depth + 1)) return true; assert(!Known.hasConflict() && "Bits known to be one AND zero?"); - if (SimplifyDemandedBits(Op.getOperand(0), ~Known.One & NewMask, - Known2, TLO, Depth+1)) + if (SimplifyDemandedBits(Op0, ~Known.One & DemandedBits, DemandedElts, Known2, TLO, + Depth + 1)) return true; assert(!Known2.hasConflict() && "Bits known to be one AND zero?"); // If all of the demanded bits are known zero on one side, return the other. // These bits cannot contribute to the result of the 'or'. - if (NewMask.isSubsetOf(Known2.One | Known.Zero)) - return TLO.CombineTo(Op, Op.getOperand(0)); - if (NewMask.isSubsetOf(Known.One | Known2.Zero)) - return TLO.CombineTo(Op, Op.getOperand(1)); + if (DemandedBits.isSubsetOf(Known2.One | Known.Zero)) + return TLO.CombineTo(Op, Op0); + if (DemandedBits.isSubsetOf(Known.One | Known2.Zero)) + return TLO.CombineTo(Op, Op1); // If the RHS is a constant, see if we can simplify it. - if (ShrinkDemandedConstant(Op, NewMask, TLO)) + if (ShrinkDemandedConstant(Op, DemandedBits, TLO)) return true; // If the operation can be done in a smaller type, do so. - if (ShrinkDemandedOp(Op, BitWidth, NewMask, TLO)) + if (ShrinkDemandedOp(Op, BitWidth, DemandedBits, TLO)) return true; // Output known-0 bits are only known if clear in both the LHS & RHS. @@ -659,78 +700,81 @@ bool TargetLowering::SimplifyDemandedBits(SDValue Op, // Output known-1 are known to be set if set in either the LHS | RHS. Known.One |= Known2.One; break; + } case ISD::XOR: { - if (SimplifyDemandedBits(Op.getOperand(1), NewMask, Known, TLO, Depth+1)) + SDValue Op0 = Op.getOperand(0); + SDValue Op1 = Op.getOperand(1); + + if (SimplifyDemandedBits(Op1, DemandedBits, DemandedElts, Known, TLO, Depth + 1)) return true; assert(!Known.hasConflict() && "Bits known to be one AND zero?"); - if (SimplifyDemandedBits(Op.getOperand(0), NewMask, Known2, TLO, Depth+1)) + if (SimplifyDemandedBits(Op0, DemandedBits, DemandedElts, Known2, TLO, Depth + 1)) return true; assert(!Known2.hasConflict() && "Bits known to be one AND zero?"); // If all of the demanded bits are known zero on one side, return the other. // These bits cannot contribute to the result of the 'xor'. - if (NewMask.isSubsetOf(Known.Zero)) - return TLO.CombineTo(Op, Op.getOperand(0)); - if (NewMask.isSubsetOf(Known2.Zero)) - return TLO.CombineTo(Op, Op.getOperand(1)); + if (DemandedBits.isSubsetOf(Known.Zero)) + return TLO.CombineTo(Op, Op0); + if (DemandedBits.isSubsetOf(Known2.Zero)) + return TLO.CombineTo(Op, Op1); // If the operation can be done in a smaller type, do so. - if (ShrinkDemandedOp(Op, BitWidth, NewMask, TLO)) + if (ShrinkDemandedOp(Op, BitWidth, DemandedBits, TLO)) return true; // If all of the unknown bits are known to be zero on one side or the other // (but not both) turn this into an *inclusive* or. // e.g. (A & C1)^(B & C2) -> (A & C1)|(B & C2) iff C1&C2 == 0 - if ((NewMask & ~Known.Zero & ~Known2.Zero) == 0) - return TLO.CombineTo(Op, TLO.DAG.getNode(ISD::OR, dl, VT, - Op.getOperand(0), - Op.getOperand(1))); + if (DemandedBits.isSubsetOf(Known.Zero | Known2.Zero)) + return TLO.CombineTo(Op, TLO.DAG.getNode(ISD::OR, dl, VT, Op0, Op1)); // Output known-0 bits are known if clear or set in both the LHS & RHS. KnownOut.Zero = (Known.Zero & Known2.Zero) | (Known.One & Known2.One); // Output known-1 are known to be set if set in only one of the LHS, RHS. KnownOut.One = (Known.Zero & Known2.One) | (Known.One & Known2.Zero); - // If all of the demanded bits on one side are known, and all of the set - // bits on that side are also known to be set on the other side, turn this - // into an AND, as we know the bits will be cleared. - // e.g. (X | C1) ^ C2 --> (X | C1) & ~C2 iff (C1&C2) == C2 - // NB: it is okay if more bits are known than are requested - if (NewMask.isSubsetOf(Known.Zero|Known.One)) { // all known on one side - if (Known.One == Known2.One) { // set bits are the same on both sides - SDValue ANDC = TLO.DAG.getConstant(~Known.One & NewMask, dl, VT); - return TLO.CombineTo(Op, TLO.DAG.getNode(ISD::AND, dl, VT, - Op.getOperand(0), ANDC)); + if (ConstantSDNode *C = isConstOrConstSplat(Op1)) { + // If one side is a constant, and all of the known set bits on the other + // side are also set in the constant, turn this into an AND, as we know + // the bits will be cleared. + // e.g. (X | C1) ^ C2 --> (X | C1) & ~C2 iff (C1&C2) == C2 + // NB: it is okay if more bits are known than are requested + if (C->getAPIntValue() == Known2.One) { + SDValue ANDC = + TLO.DAG.getConstant(~C->getAPIntValue() & DemandedBits, dl, VT); + return TLO.CombineTo(Op, TLO.DAG.getNode(ISD::AND, dl, VT, Op0, ANDC)); } - } - // If the RHS is a constant, see if we can change it. Don't alter a -1 - // constant because that's a 'not' op, and that is better for combining and - // codegen. - ConstantSDNode *C = isConstOrConstSplat(Op.getOperand(1)); - if (C && !C->isAllOnesValue()) { - if (NewMask.isSubsetOf(C->getAPIntValue())) { - // We're flipping all demanded bits. Flip the undemanded bits too. - SDValue New = TLO.DAG.getNOT(dl, Op.getOperand(0), VT); - return TLO.CombineTo(Op, New); + // If the RHS is a constant, see if we can change it. Don't alter a -1 + // constant because that's a 'not' op, and that is better for combining + // and codegen. + if (!C->isAllOnesValue()) { + if (DemandedBits.isSubsetOf(C->getAPIntValue())) { + // We're flipping all demanded bits. Flip the undemanded bits too. + SDValue New = TLO.DAG.getNOT(dl, Op0, VT); + return TLO.CombineTo(Op, New); + } + // If we can't turn this into a 'not', try to shrink the constant. + if (ShrinkDemandedConstant(Op, DemandedBits, TLO)) + return true; } - // If we can't turn this into a 'not', try to shrink the constant. - if (ShrinkDemandedConstant(Op, NewMask, TLO)) - return true; } Known = std::move(KnownOut); break; } case ISD::SELECT: - if (SimplifyDemandedBits(Op.getOperand(2), NewMask, Known, TLO, Depth+1)) + if (SimplifyDemandedBits(Op.getOperand(2), DemandedBits, Known, TLO, + Depth + 1)) return true; - if (SimplifyDemandedBits(Op.getOperand(1), NewMask, Known2, TLO, Depth+1)) + if (SimplifyDemandedBits(Op.getOperand(1), DemandedBits, Known2, TLO, + Depth + 1)) return true; assert(!Known.hasConflict() && "Bits known to be one AND zero?"); assert(!Known2.hasConflict() && "Bits known to be one AND zero?"); // If the operands are constants, see if we can simplify them. - if (ShrinkDemandedConstant(Op, NewMask, TLO)) + if (ShrinkDemandedConstant(Op, DemandedBits, TLO)) return true; // Only known if known in both the LHS and RHS. @@ -738,15 +782,17 @@ bool TargetLowering::SimplifyDemandedBits(SDValue Op, Known.Zero &= Known2.Zero; break; case ISD::SELECT_CC: - if (SimplifyDemandedBits(Op.getOperand(3), NewMask, Known, TLO, Depth+1)) + if (SimplifyDemandedBits(Op.getOperand(3), DemandedBits, Known, TLO, + Depth + 1)) return true; - if (SimplifyDemandedBits(Op.getOperand(2), NewMask, Known2, TLO, Depth+1)) + if (SimplifyDemandedBits(Op.getOperand(2), DemandedBits, Known2, TLO, + Depth + 1)) return true; assert(!Known.hasConflict() && "Bits known to be one AND zero?"); assert(!Known2.hasConflict() && "Bits known to be one AND zero?"); // If the operands are constants, see if we can simplify them. - if (ShrinkDemandedConstant(Op, NewMask, TLO)) + if (ShrinkDemandedConstant(Op, DemandedBits, TLO)) return true; // Only known if known in both the LHS and RHS. @@ -760,7 +806,8 @@ bool TargetLowering::SimplifyDemandedBits(SDValue Op, // If (1) we only need the sign-bit, (2) the setcc operands are the same // width as the setcc result, and (3) the result of a setcc conforms to 0 or // -1, we may be able to bypass the setcc. - if (NewMask.isSignMask() && Op0.getScalarValueSizeInBits() == BitWidth && + if (DemandedBits.isSignMask() && + Op0.getScalarValueSizeInBits() == BitWidth && getBooleanContents(VT) == BooleanContent::ZeroOrNegativeOneBooleanContent) { // If we're testing X < 0, then this compare isn't needed - just use X! @@ -780,10 +827,11 @@ bool TargetLowering::SimplifyDemandedBits(SDValue Op, Known.Zero.setBitsFrom(1); break; } - case ISD::SHL: - if (ConstantSDNode *SA = isConstOrConstSplat(Op.getOperand(1))) { - SDValue InOp = Op.getOperand(0); + case ISD::SHL: { + SDValue Op0 = Op.getOperand(0); + SDValue Op1 = Op.getOperand(1); + if (ConstantSDNode *SA = isConstOrConstSplat(Op1)) { // If the shift count is an invalid immediate, don't do anything. if (SA->getAPIntValue().uge(BitWidth)) break; @@ -793,90 +841,91 @@ bool TargetLowering::SimplifyDemandedBits(SDValue Op, // If this is ((X >>u C1) << ShAmt), see if we can simplify this into a // single shift. We can do this if the bottom bits (which are shifted // out) are never demanded. - if (InOp.getOpcode() == ISD::SRL) { - if (ConstantSDNode *SA2 = isConstOrConstSplat(InOp.getOperand(1))) { - if (ShAmt && (NewMask & APInt::getLowBitsSet(BitWidth, ShAmt)) == 0) { + if (Op0.getOpcode() == ISD::SRL) { + if (ShAmt && + (DemandedBits & APInt::getLowBitsSet(BitWidth, ShAmt)) == 0) { + if (ConstantSDNode *SA2 = isConstOrConstSplat(Op0.getOperand(1))) { if (SA2->getAPIntValue().ult(BitWidth)) { unsigned C1 = SA2->getZExtValue(); unsigned Opc = ISD::SHL; - int Diff = ShAmt-C1; + int Diff = ShAmt - C1; if (Diff < 0) { Diff = -Diff; Opc = ISD::SRL; } - SDValue NewSA = - TLO.DAG.getConstant(Diff, dl, Op.getOperand(1).getValueType()); - return TLO.CombineTo(Op, TLO.DAG.getNode(Opc, dl, VT, - InOp.getOperand(0), - NewSA)); + SDValue NewSA = TLO.DAG.getConstant(Diff, dl, Op1.getValueType()); + return TLO.CombineTo( + Op, TLO.DAG.getNode(Opc, dl, VT, Op0.getOperand(0), NewSA)); } } } } - if (SimplifyDemandedBits(InOp, NewMask.lshr(ShAmt), Known, TLO, Depth+1)) + if (SimplifyDemandedBits(Op0, DemandedBits.lshr(ShAmt), DemandedElts, Known, TLO, + Depth + 1)) return true; // Convert (shl (anyext x, c)) to (anyext (shl x, c)) if the high bits // are not demanded. This will likely allow the anyext to be folded away. - if (InOp.getNode()->getOpcode() == ISD::ANY_EXTEND) { - SDValue InnerOp = InOp.getOperand(0); + if (Op0.getOpcode() == ISD::ANY_EXTEND) { + SDValue InnerOp = Op0.getOperand(0); EVT InnerVT = InnerOp.getValueType(); unsigned InnerBits = InnerVT.getScalarSizeInBits(); - if (ShAmt < InnerBits && NewMask.getActiveBits() <= InnerBits && + if (ShAmt < InnerBits && DemandedBits.getActiveBits() <= InnerBits && isTypeDesirableForOp(ISD::SHL, InnerVT)) { EVT ShTy = getShiftAmountTy(InnerVT, DL); if (!APInt(BitWidth, ShAmt).isIntN(ShTy.getSizeInBits())) ShTy = InnerVT; SDValue NarrowShl = - TLO.DAG.getNode(ISD::SHL, dl, InnerVT, InnerOp, - TLO.DAG.getConstant(ShAmt, dl, ShTy)); - return - TLO.CombineTo(Op, - TLO.DAG.getNode(ISD::ANY_EXTEND, dl, VT, NarrowShl)); + TLO.DAG.getNode(ISD::SHL, dl, InnerVT, InnerOp, + TLO.DAG.getConstant(ShAmt, dl, ShTy)); + return TLO.CombineTo( + Op, TLO.DAG.getNode(ISD::ANY_EXTEND, dl, VT, NarrowShl)); } // Repeat the SHL optimization above in cases where an extension // intervenes: (shl (anyext (shr x, c1)), c2) to // (shl (anyext x), c2-c1). This requires that the bottom c1 bits // aren't demanded (as above) and that the shifted upper c1 bits of // x aren't demanded. - if (InOp.hasOneUse() && InnerOp.getOpcode() == ISD::SRL && + if (Op0.hasOneUse() && InnerOp.getOpcode() == ISD::SRL && InnerOp.hasOneUse()) { - if (ConstantSDNode *SA2 = isConstOrConstSplat(InnerOp.getOperand(1))) { + if (ConstantSDNode *SA2 = + isConstOrConstSplat(InnerOp.getOperand(1))) { unsigned InnerShAmt = SA2->getLimitedValue(InnerBits); - if (InnerShAmt < ShAmt && - InnerShAmt < InnerBits && - NewMask.getActiveBits() <= (InnerBits - InnerShAmt + ShAmt) && - NewMask.countTrailingZeros() >= ShAmt) { - SDValue NewSA = - TLO.DAG.getConstant(ShAmt - InnerShAmt, dl, - Op.getOperand(1).getValueType()); + if (InnerShAmt < ShAmt && InnerShAmt < InnerBits && + DemandedBits.getActiveBits() <= + (InnerBits - InnerShAmt + ShAmt) && + DemandedBits.countTrailingZeros() >= ShAmt) { + SDValue NewSA = TLO.DAG.getConstant(ShAmt - InnerShAmt, dl, + Op1.getValueType()); SDValue NewExt = TLO.DAG.getNode(ISD::ANY_EXTEND, dl, VT, InnerOp.getOperand(0)); - return TLO.CombineTo(Op, TLO.DAG.getNode(ISD::SHL, dl, VT, - NewExt, NewSA)); + return TLO.CombineTo( + Op, TLO.DAG.getNode(ISD::SHL, dl, VT, NewExt, NewSA)); } } } } Known.Zero <<= ShAmt; - Known.One <<= ShAmt; + Known.One <<= ShAmt; // low bits known zero. Known.Zero.setLowBits(ShAmt); } break; - case ISD::SRL: - if (ConstantSDNode *SA = isConstOrConstSplat(Op.getOperand(1))) { - SDValue InOp = Op.getOperand(0); + } + case ISD::SRL: { + SDValue Op0 = Op.getOperand(0); + SDValue Op1 = Op.getOperand(1); + if (ConstantSDNode *SA = isConstOrConstSplat(Op1)) { // If the shift count is an invalid immediate, don't do anything. if (SA->getAPIntValue().uge(BitWidth)) break; unsigned ShAmt = SA->getZExtValue(); - APInt InDemandedMask = (NewMask << ShAmt); + APInt InDemandedMask = (DemandedBits << ShAmt); // If the shift is exact, then it does demand the low bits (and knows that // they are zero). @@ -886,56 +935,56 @@ bool TargetLowering::SimplifyDemandedBits(SDValue Op, // If this is ((X << C1) >>u ShAmt), see if we can simplify this into a // single shift. We can do this if the top bits (which are shifted out) // are never demanded. - if (InOp.getOpcode() == ISD::SHL) { - if (ConstantSDNode *SA2 = isConstOrConstSplat(InOp.getOperand(1))) { + if (Op0.getOpcode() == ISD::SHL) { + if (ConstantSDNode *SA2 = isConstOrConstSplat(Op0.getOperand(1))) { if (ShAmt && - (NewMask & APInt::getHighBitsSet(BitWidth, ShAmt)) == 0) { + (DemandedBits & APInt::getHighBitsSet(BitWidth, ShAmt)) == 0) { if (SA2->getAPIntValue().ult(BitWidth)) { unsigned C1 = SA2->getZExtValue(); unsigned Opc = ISD::SRL; - int Diff = ShAmt-C1; + int Diff = ShAmt - C1; if (Diff < 0) { Diff = -Diff; Opc = ISD::SHL; } - SDValue NewSA = - TLO.DAG.getConstant(Diff, dl, Op.getOperand(1).getValueType()); - return TLO.CombineTo(Op, TLO.DAG.getNode(Opc, dl, VT, - InOp.getOperand(0), - NewSA)); + SDValue NewSA = TLO.DAG.getConstant(Diff, dl, Op1.getValueType()); + return TLO.CombineTo( + Op, TLO.DAG.getNode(Opc, dl, VT, Op0.getOperand(0), NewSA)); } } } } // Compute the new bits that are at the top now. - if (SimplifyDemandedBits(InOp, InDemandedMask, Known, TLO, Depth+1)) + if (SimplifyDemandedBits(Op0, InDemandedMask, DemandedElts, Known, TLO, Depth + 1)) return true; assert(!Known.hasConflict() && "Bits known to be one AND zero?"); Known.Zero.lshrInPlace(ShAmt); Known.One.lshrInPlace(ShAmt); - Known.Zero.setHighBits(ShAmt); // High bits known zero. + Known.Zero.setHighBits(ShAmt); // High bits known zero. } break; - case ISD::SRA: + } + case ISD::SRA: { + SDValue Op0 = Op.getOperand(0); + SDValue Op1 = Op.getOperand(1); + // If this is an arithmetic shift right and only the low-bit is set, we can // always convert this into a logical shr, even if the shift amount is // variable. The low bit of the shift cannot be an input sign bit unless // the shift amount is >= the size of the datatype, which is undefined. - if (NewMask.isOneValue()) - return TLO.CombineTo(Op, - TLO.DAG.getNode(ISD::SRL, dl, VT, Op.getOperand(0), - Op.getOperand(1))); + if (DemandedBits.isOneValue()) + return TLO.CombineTo(Op, TLO.DAG.getNode(ISD::SRL, dl, VT, Op0, Op1)); - if (ConstantSDNode *SA = isConstOrConstSplat(Op.getOperand(1))) { + if (ConstantSDNode *SA = isConstOrConstSplat(Op1)) { // If the shift count is an invalid immediate, don't do anything. if (SA->getAPIntValue().uge(BitWidth)) break; unsigned ShAmt = SA->getZExtValue(); - APInt InDemandedMask = (NewMask << ShAmt); + APInt InDemandedMask = (DemandedBits << ShAmt); // If the shift is exact, then it does demand the low bits (and knows that // they are zero). @@ -944,11 +993,10 @@ bool TargetLowering::SimplifyDemandedBits(SDValue Op, // If any of the demanded bits are produced by the sign extension, we also // demand the input sign bit. - if (NewMask.countLeadingZeros() < ShAmt) + if (DemandedBits.countLeadingZeros() < ShAmt) InDemandedMask.setSignBit(); - if (SimplifyDemandedBits(Op.getOperand(0), InDemandedMask, Known, TLO, - Depth+1)) + if (SimplifyDemandedBits(Op0, InDemandedMask, DemandedElts, Known, TLO, Depth + 1)) return true; assert(!Known.hasConflict() && "Bits known to be one AND zero?"); Known.Zero.lshrInPlace(ShAmt); @@ -957,22 +1005,19 @@ bool TargetLowering::SimplifyDemandedBits(SDValue Op, // If the input sign bit is known to be zero, or if none of the top bits // are demanded, turn this into an unsigned shift right. if (Known.Zero[BitWidth - ShAmt - 1] || - NewMask.countLeadingZeros() >= ShAmt) { + DemandedBits.countLeadingZeros() >= ShAmt) { SDNodeFlags Flags; Flags.setExact(Op->getFlags().hasExact()); - return TLO.CombineTo(Op, - TLO.DAG.getNode(ISD::SRL, dl, VT, Op.getOperand(0), - Op.getOperand(1), Flags)); + return TLO.CombineTo( + Op, TLO.DAG.getNode(ISD::SRL, dl, VT, Op0, Op1, Flags)); } - int Log2 = NewMask.exactLogBase2(); + int Log2 = DemandedBits.exactLogBase2(); if (Log2 >= 0) { // The bit must come from the sign. SDValue NewSA = - TLO.DAG.getConstant(BitWidth - 1 - Log2, dl, - Op.getOperand(1).getValueType()); - return TLO.CombineTo(Op, TLO.DAG.getNode(ISD::SRL, dl, VT, - Op.getOperand(0), NewSA)); + TLO.DAG.getConstant(BitWidth - 1 - Log2, dl, Op1.getValueType()); + return TLO.CombineTo(Op, TLO.DAG.getNode(ISD::SRL, dl, VT, Op0, NewSA)); } if (Known.One[BitWidth - ShAmt - 1]) @@ -980,15 +1025,16 @@ bool TargetLowering::SimplifyDemandedBits(SDValue Op, Known.One.setHighBits(ShAmt); } break; + } case ISD::SIGN_EXTEND_INREG: { + SDValue Op0 = Op.getOperand(0); EVT ExVT = cast<VTSDNode>(Op.getOperand(1))->getVT(); unsigned ExVTBits = ExVT.getScalarSizeInBits(); // If we only care about the highest bit, don't bother shifting right. - if (NewMask.isSignMask()) { - SDValue InOp = Op.getOperand(0); + if (DemandedBits.isSignMask()) { bool AlreadySignExtended = - TLO.DAG.ComputeNumSignBits(InOp) >= BitWidth-ExVTBits+1; + TLO.DAG.ComputeNumSignBits(Op0) >= BitWidth - ExVTBits + 1; // However if the input is already sign extended we expect the sign // extension to be dropped altogether later and do not simplify. if (!AlreadySignExtended) { @@ -998,25 +1044,24 @@ bool TargetLowering::SimplifyDemandedBits(SDValue Op, if (TLO.LegalTypes() && !ShiftAmtTy.isVector()) ShiftAmtTy = getShiftAmountTy(ShiftAmtTy, DL); - SDValue ShiftAmt = TLO.DAG.getConstant(BitWidth - ExVTBits, dl, - ShiftAmtTy); - return TLO.CombineTo(Op, TLO.DAG.getNode(ISD::SHL, dl, VT, InOp, - ShiftAmt)); + SDValue ShiftAmt = + TLO.DAG.getConstant(BitWidth - ExVTBits, dl, ShiftAmtTy); + return TLO.CombineTo(Op, + TLO.DAG.getNode(ISD::SHL, dl, VT, Op0, ShiftAmt)); } } // If none of the extended bits are demanded, eliminate the sextinreg. - if (NewMask.getActiveBits() <= ExVTBits) - return TLO.CombineTo(Op, Op.getOperand(0)); + if (DemandedBits.getActiveBits() <= ExVTBits) + return TLO.CombineTo(Op, Op0); - APInt InputDemandedBits = NewMask.getLoBits(ExVTBits); + APInt InputDemandedBits = DemandedBits.getLoBits(ExVTBits); // Since the sign extended bits are demanded, we know that the sign // bit is demanded. InputDemandedBits.setBit(ExVTBits - 1); - if (SimplifyDemandedBits(Op.getOperand(0), InputDemandedBits, - Known, TLO, Depth+1)) + if (SimplifyDemandedBits(Op0, InputDemandedBits, Known, TLO, Depth + 1)) return true; assert(!Known.hasConflict() && "Bits known to be one AND zero?"); @@ -1025,14 +1070,14 @@ bool TargetLowering::SimplifyDemandedBits(SDValue Op, // If the input sign bit is known zero, convert this into a zero extension. if (Known.Zero[ExVTBits - 1]) - return TLO.CombineTo(Op, TLO.DAG.getZeroExtendInReg( - Op.getOperand(0), dl, ExVT.getScalarType())); + return TLO.CombineTo( + Op, TLO.DAG.getZeroExtendInReg(Op0, dl, ExVT.getScalarType())); APInt Mask = APInt::getLowBitsSet(BitWidth, ExVTBits); - if (Known.One[ExVTBits - 1]) { // Input sign bit known set + if (Known.One[ExVTBits - 1]) { // Input sign bit known set Known.One.setBitsFrom(ExVTBits); Known.Zero &= Mask; - } else { // Input sign bit unknown + } else { // Input sign bit unknown Known.Zero &= Mask; Known.One &= Mask; } @@ -1042,8 +1087,8 @@ bool TargetLowering::SimplifyDemandedBits(SDValue Op, EVT HalfVT = Op.getOperand(0).getValueType(); unsigned HalfBitWidth = HalfVT.getScalarSizeInBits(); - APInt MaskLo = NewMask.getLoBits(HalfBitWidth).trunc(HalfBitWidth); - APInt MaskHi = NewMask.getHiBits(HalfBitWidth).trunc(HalfBitWidth); + APInt MaskLo = DemandedBits.getLoBits(HalfBitWidth).trunc(HalfBitWidth); + APInt MaskHi = DemandedBits.getHiBits(HalfBitWidth).trunc(HalfBitWidth); KnownBits KnownLo, KnownHi; @@ -1061,36 +1106,35 @@ bool TargetLowering::SimplifyDemandedBits(SDValue Op, break; } case ISD::ZERO_EXTEND: { - unsigned OperandBitWidth = Op.getOperand(0).getScalarValueSizeInBits(); + SDValue Src = Op.getOperand(0); + unsigned InBits = Src.getScalarValueSizeInBits(); // If none of the top bits are demanded, convert this into an any_extend. - if (NewMask.getActiveBits() <= OperandBitWidth) - return TLO.CombineTo(Op, TLO.DAG.getNode(ISD::ANY_EXTEND, dl, VT, - Op.getOperand(0))); + if (DemandedBits.getActiveBits() <= InBits) + return TLO.CombineTo(Op, TLO.DAG.getNode(ISD::ANY_EXTEND, dl, VT, Src)); - APInt InMask = NewMask.trunc(OperandBitWidth); - if (SimplifyDemandedBits(Op.getOperand(0), InMask, Known, TLO, Depth+1)) + APInt InDemandedBits = DemandedBits.trunc(InBits); + if (SimplifyDemandedBits(Src, InDemandedBits, Known, TLO, Depth+1)) return true; assert(!Known.hasConflict() && "Bits known to be one AND zero?"); Known = Known.zext(BitWidth); - Known.Zero.setBitsFrom(OperandBitWidth); + Known.Zero.setBitsFrom(InBits); break; } case ISD::SIGN_EXTEND: { - unsigned InBits = Op.getOperand(0).getValueType().getScalarSizeInBits(); + SDValue Src = Op.getOperand(0); + unsigned InBits = Src.getScalarValueSizeInBits(); // If none of the top bits are demanded, convert this into an any_extend. - if (NewMask.getActiveBits() <= InBits) - return TLO.CombineTo(Op,TLO.DAG.getNode(ISD::ANY_EXTEND, dl, VT, - Op.getOperand(0))); + if (DemandedBits.getActiveBits() <= InBits) + return TLO.CombineTo(Op, TLO.DAG.getNode(ISD::ANY_EXTEND, dl, VT, Src)); // Since some of the sign extended bits are demanded, we know that the sign // bit is demanded. - APInt InDemandedBits = NewMask.trunc(InBits); + APInt InDemandedBits = DemandedBits.trunc(InBits); InDemandedBits.setBit(InBits - 1); - if (SimplifyDemandedBits(Op.getOperand(0), InDemandedBits, Known, TLO, - Depth+1)) + if (SimplifyDemandedBits(Src, InDemandedBits, Known, TLO, Depth + 1)) return true; assert(!Known.hasConflict() && "Bits known to be one AND zero?"); // If the sign bit is known one, the top bits match. @@ -1098,34 +1142,55 @@ bool TargetLowering::SimplifyDemandedBits(SDValue Op, // If the sign bit is known zero, convert this to a zero extend. if (Known.isNonNegative()) - return TLO.CombineTo(Op, TLO.DAG.getNode(ISD::ZERO_EXTEND, dl, VT, - Op.getOperand(0))); + return TLO.CombineTo(Op, TLO.DAG.getNode(ISD::ZERO_EXTEND, dl, VT, Src)); + break; + } + case ISD::SIGN_EXTEND_VECTOR_INREG: { + // TODO - merge this with SIGN_EXTEND above? + SDValue Src = Op.getOperand(0); + unsigned InBits = Src.getScalarValueSizeInBits(); + + APInt InDemandedBits = DemandedBits.trunc(InBits); + + // If some of the sign extended bits are demanded, we know that the sign + // bit is demanded. + if (InBits < DemandedBits.getActiveBits()) + InDemandedBits.setBit(InBits - 1); + + if (SimplifyDemandedBits(Src, InDemandedBits, Known, TLO, Depth + 1)) + return true; + assert(!Known.hasConflict() && "Bits known to be one AND zero?"); + // If the sign bit is known one, the top bits match. + Known = Known.sext(BitWidth); break; } case ISD::ANY_EXTEND: { - unsigned OperandBitWidth = Op.getOperand(0).getScalarValueSizeInBits(); - APInt InMask = NewMask.trunc(OperandBitWidth); - if (SimplifyDemandedBits(Op.getOperand(0), InMask, Known, TLO, Depth+1)) + SDValue Src = Op.getOperand(0); + unsigned InBits = Src.getScalarValueSizeInBits(); + APInt InDemandedBits = DemandedBits.trunc(InBits); + if (SimplifyDemandedBits(Src, InDemandedBits, Known, TLO, Depth+1)) return true; assert(!Known.hasConflict() && "Bits known to be one AND zero?"); Known = Known.zext(BitWidth); break; } case ISD::TRUNCATE: { + SDValue Src = Op.getOperand(0); + // Simplify the input, using demanded bit information, and compute the known // zero/one bits live out. - unsigned OperandBitWidth = Op.getOperand(0).getScalarValueSizeInBits(); - APInt TruncMask = NewMask.zext(OperandBitWidth); - if (SimplifyDemandedBits(Op.getOperand(0), TruncMask, Known, TLO, Depth+1)) + unsigned OperandBitWidth = Src.getScalarValueSizeInBits(); + APInt TruncMask = DemandedBits.zext(OperandBitWidth); + if (SimplifyDemandedBits(Src, TruncMask, Known, TLO, Depth + 1)) return true; Known = Known.trunc(BitWidth); // If the input is only used by this truncate, see if we can shrink it based // on the known demanded bits. - if (Op.getOperand(0).getNode()->hasOneUse()) { - SDValue In = Op.getOperand(0); - switch (In.getOpcode()) { - default: break; + if (Src.getNode()->hasOneUse()) { + switch (Src.getOpcode()) { + default: + break; case ISD::SRL: // Shrink SRL by a constant if none of the high bits shifted in are // demanded. @@ -1133,10 +1198,10 @@ bool TargetLowering::SimplifyDemandedBits(SDValue Op, // Do not turn (vt1 truncate (vt2 srl)) into (vt1 srl) if vt1 is // undesirable. break; - ConstantSDNode *ShAmt = dyn_cast<ConstantSDNode>(In.getOperand(1)); + ConstantSDNode *ShAmt = dyn_cast<ConstantSDNode>(Src.getOperand(1)); if (!ShAmt) break; - SDValue Shift = In.getOperand(1); + SDValue Shift = Src.getOperand(1); if (TLO.LegalTypes()) { uint64_t ShVal = ShAmt->getZExtValue(); Shift = TLO.DAG.getConstant(ShVal, dl, getShiftAmountTy(VT, DL)); @@ -1148,13 +1213,13 @@ bool TargetLowering::SimplifyDemandedBits(SDValue Op, HighBits.lshrInPlace(ShAmt->getZExtValue()); HighBits = HighBits.trunc(BitWidth); - if (!(HighBits & NewMask)) { + if (!(HighBits & DemandedBits)) { // None of the shifted in bits are needed. Add a truncate of the // shift input, then shift it. - SDValue NewTrunc = TLO.DAG.getNode(ISD::TRUNCATE, dl, VT, - In.getOperand(0)); - return TLO.CombineTo(Op, TLO.DAG.getNode(ISD::SRL, dl, VT, NewTrunc, - Shift)); + SDValue NewTrunc = + TLO.DAG.getNode(ISD::TRUNCATE, dl, VT, Src.getOperand(0)); + return TLO.CombineTo( + Op, TLO.DAG.getNode(ISD::SRL, dl, VT, NewTrunc, Shift)); } } break; @@ -1169,7 +1234,7 @@ bool TargetLowering::SimplifyDemandedBits(SDValue Op, // demanded by its users. EVT ZVT = cast<VTSDNode>(Op.getOperand(1))->getVT(); APInt InMask = APInt::getLowBitsSet(BitWidth, ZVT.getSizeInBits()); - if (SimplifyDemandedBits(Op.getOperand(0), ~InMask | NewMask, + if (SimplifyDemandedBits(Op.getOperand(0), ~InMask | DemandedBits, Known, TLO, Depth+1)) return true; assert(!Known.hasConflict() && "Bits known to be one AND zero?"); @@ -1177,50 +1242,111 @@ bool TargetLowering::SimplifyDemandedBits(SDValue Op, Known.Zero |= ~InMask; break; } - case ISD::BITCAST: + case ISD::EXTRACT_VECTOR_ELT: { + SDValue Src = Op.getOperand(0); + SDValue Idx = Op.getOperand(1); + unsigned NumSrcElts = Src.getValueType().getVectorNumElements(); + unsigned EltBitWidth = Src.getScalarValueSizeInBits(); + + // Demand the bits from every vector element without a constant index. + APInt DemandedSrcElts = APInt::getAllOnesValue(NumSrcElts); + if (auto *CIdx = dyn_cast<ConstantSDNode>(Idx)) + if (CIdx->getAPIntValue().ult(NumSrcElts)) + DemandedSrcElts = APInt::getOneBitSet(NumSrcElts, CIdx->getZExtValue()); + + // If BitWidth > EltBitWidth the value is anyext:ed. So we do not know + // anything about the extended bits. + APInt DemandedSrcBits = DemandedBits; + if (BitWidth > EltBitWidth) + DemandedSrcBits = DemandedSrcBits.trunc(EltBitWidth); + + if (SimplifyDemandedBits(Src, DemandedSrcBits, DemandedSrcElts, Known2, TLO, + Depth + 1)) + return true; + + Known = Known2; + if (BitWidth > EltBitWidth) + Known = Known.zext(BitWidth); + break; + } + case ISD::BITCAST: { + SDValue Src = Op.getOperand(0); + EVT SrcVT = Src.getValueType(); + unsigned NumSrcEltBits = SrcVT.getScalarSizeInBits(); + // If this is an FP->Int bitcast and if the sign bit is the only // thing demanded, turn this into a FGETSIGN. - if (!TLO.LegalOperations() && !VT.isVector() && - !Op.getOperand(0).getValueType().isVector() && - NewMask == APInt::getSignMask(Op.getValueSizeInBits()) && - Op.getOperand(0).getValueType().isFloatingPoint()) { + if (!TLO.LegalOperations() && !VT.isVector() && !SrcVT.isVector() && + DemandedBits == APInt::getSignMask(Op.getValueSizeInBits()) && + SrcVT.isFloatingPoint()) { bool OpVTLegal = isOperationLegalOrCustom(ISD::FGETSIGN, VT); - bool i32Legal = isOperationLegalOrCustom(ISD::FGETSIGN, MVT::i32); - if ((OpVTLegal || i32Legal) && VT.isSimple() && - Op.getOperand(0).getValueType() != MVT::f16 && - Op.getOperand(0).getValueType() != MVT::f128) { + bool i32Legal = isOperationLegalOrCustom(ISD::FGETSIGN, MVT::i32); + if ((OpVTLegal || i32Legal) && VT.isSimple() && SrcVT != MVT::f16 && + SrcVT != MVT::f128) { // Cannot eliminate/lower SHL for f128 yet. EVT Ty = OpVTLegal ? VT : MVT::i32; // Make a FGETSIGN + SHL to move the sign bit into the appropriate // place. We expect the SHL to be eliminated by other optimizations. - SDValue Sign = TLO.DAG.getNode(ISD::FGETSIGN, dl, Ty, Op.getOperand(0)); + SDValue Sign = TLO.DAG.getNode(ISD::FGETSIGN, dl, Ty, Src); unsigned OpVTSizeInBits = Op.getValueSizeInBits(); if (!OpVTLegal && OpVTSizeInBits > 32) Sign = TLO.DAG.getNode(ISD::ZERO_EXTEND, dl, VT, Sign); unsigned ShVal = Op.getValueSizeInBits() - 1; SDValue ShAmt = TLO.DAG.getConstant(ShVal, dl, VT); - return TLO.CombineTo(Op, TLO.DAG.getNode(ISD::SHL, dl, VT, Sign, ShAmt)); + return TLO.CombineTo(Op, + TLO.DAG.getNode(ISD::SHL, dl, VT, Sign, ShAmt)); + } + } + // If bitcast from a vector, see if we can use SimplifyDemandedVectorElts by + // demanding the element if any bits from it are demanded. + // TODO - bigendian once we have test coverage. + // TODO - bool vectors once SimplifyDemandedVectorElts has SETCC support. + if (SrcVT.isVector() && NumSrcEltBits > 1 && + (BitWidth % NumSrcEltBits) == 0 && + TLO.DAG.getDataLayout().isLittleEndian()) { + unsigned Scale = BitWidth / NumSrcEltBits; + auto GetDemandedSubMask = [&](APInt &DemandedSubElts) -> bool { + DemandedSubElts = APInt::getNullValue(Scale); + for (unsigned i = 0; i != Scale; ++i) { + unsigned Offset = i * NumSrcEltBits; + APInt Sub = DemandedBits.extractBits(NumSrcEltBits, Offset); + if (!Sub.isNullValue()) + DemandedSubElts.setBit(i); + } + return true; + }; + + APInt DemandedSubElts; + if (GetDemandedSubMask(DemandedSubElts)) { + unsigned NumSrcElts = SrcVT.getVectorNumElements(); + APInt DemandedElts = APInt::getSplat(NumSrcElts, DemandedSubElts); + + APInt KnownUndef, KnownZero; + if (SimplifyDemandedVectorElts(Src, DemandedElts, KnownUndef, KnownZero, + TLO, Depth + 1)) + return true; } } // If this is a bitcast, let computeKnownBits handle it. Only do this on a // recursive call where Known may be useful to the caller. if (Depth > 0) { - TLO.DAG.computeKnownBits(Op, Known, Depth); + Known = TLO.DAG.computeKnownBits(Op, Depth); return false; } break; + } case ISD::ADD: case ISD::MUL: case ISD::SUB: { // Add, Sub, and Mul don't demand any bits in positions beyond that // of the highest bit demanded of them. SDValue Op0 = Op.getOperand(0), Op1 = Op.getOperand(1); - unsigned NewMaskLZ = NewMask.countLeadingZeros(); - APInt LoMask = APInt::getLowBitsSet(BitWidth, BitWidth - NewMaskLZ); - if (SimplifyDemandedBits(Op0, LoMask, Known2, TLO, Depth + 1) || - SimplifyDemandedBits(Op1, LoMask, Known2, TLO, Depth + 1) || + unsigned DemandedBitsLZ = DemandedBits.countLeadingZeros(); + APInt LoMask = APInt::getLowBitsSet(BitWidth, BitWidth - DemandedBitsLZ); + if (SimplifyDemandedBits(Op0, LoMask, DemandedElts, Known2, TLO, Depth + 1) || + SimplifyDemandedBits(Op1, LoMask, DemandedElts, Known2, TLO, Depth + 1) || // See if the operation should be performed at a smaller bit width. - ShrinkDemandedOp(Op, BitWidth, NewMask, TLO)) { + ShrinkDemandedOp(Op, BitWidth, DemandedBits, TLO)) { SDNodeFlags Flags = Op.getNode()->getFlags(); if (Flags.hasNoSignedWrap() || Flags.hasNoUnsignedWrap()) { // Disable the nsw and nuw flags. We can no longer guarantee that we @@ -1240,7 +1366,7 @@ bool TargetLowering::SimplifyDemandedBits(SDValue Op, // patterns (eg, 'blsr' on x86). Don't bother changing 1 to -1 because that // is probably not useful (and could be detrimental). ConstantSDNode *C = isConstOrConstSplat(Op1); - APInt HighMask = APInt::getHighBitsSet(NewMask.getBitWidth(), NewMaskLZ); + APInt HighMask = APInt::getHighBitsSet(BitWidth, DemandedBitsLZ); if (C && !C->isAllOnesValue() && !C->isOne() && (C->getAPIntValue() | HighMask).isAllOnesValue()) { SDValue Neg1 = TLO.DAG.getAllOnesConstant(dl, VT); @@ -1257,24 +1383,34 @@ bool TargetLowering::SimplifyDemandedBits(SDValue Op, LLVM_FALLTHROUGH; } default: + if (Op.getOpcode() >= ISD::BUILTIN_OP_END) { + if (SimplifyDemandedBitsForTargetNode(Op, DemandedBits, DemandedElts, + Known, TLO, Depth)) + return true; + break; + } + // Just use computeKnownBits to compute output bits. - TLO.DAG.computeKnownBits(Op, Known, Depth); + Known = TLO.DAG.computeKnownBits(Op, DemandedElts, Depth); break; } // If we know the value of all of the demanded bits, return this as a // constant. - if (NewMask.isSubsetOf(Known.Zero|Known.One)) { + if (DemandedBits.isSubsetOf(Known.Zero | Known.One)) { // Avoid folding to a constant if any OpaqueConstant is involved. const SDNode *N = Op.getNode(); for (SDNodeIterator I = SDNodeIterator::begin(N), - E = SDNodeIterator::end(N); I != E; ++I) { + E = SDNodeIterator::end(N); + I != E; ++I) { SDNode *Op = *I; if (ConstantSDNode *C = dyn_cast<ConstantSDNode>(Op)) if (C->isOpaque()) return false; } - return TLO.CombineTo(Op, TLO.DAG.getConstant(Known.One, dl, VT)); + // TODO: Handle float bits as well. + if (VT.isInteger()) + return TLO.CombineTo(Op, TLO.DAG.getConstant(Known.One, dl, VT)); } return false; @@ -1291,8 +1427,10 @@ bool TargetLowering::SimplifyDemandedVectorElts(SDValue Op, bool Simplified = SimplifyDemandedVectorElts(Op, DemandedElts, KnownUndef, KnownZero, TLO); - if (Simplified) + if (Simplified) { + DCI.AddToWorklist(Op.getNode()); DCI.CommitTargetLoweringOpt(TLO); + } return Simplified; } @@ -1371,6 +1509,23 @@ bool TargetLowering::SimplifyDemandedVectorElts( TLO, Depth + 1)) return true; + // Try calling SimplifyDemandedBits, converting demanded elts to the bits + // of the large element. + // TODO - bigendian once we have test coverage. + if (TLO.DAG.getDataLayout().isLittleEndian()) { + unsigned SrcEltSizeInBits = SrcVT.getScalarSizeInBits(); + APInt SrcDemandedBits = APInt::getNullValue(SrcEltSizeInBits); + for (unsigned i = 0; i != NumElts; ++i) + if (DemandedElts[i]) { + unsigned Ofs = (i % Scale) * EltSizeInBits; + SrcDemandedBits.setBits(Ofs, Ofs + EltSizeInBits); + } + + KnownBits Known; + if (SimplifyDemandedBits(Src, SrcDemandedBits, Known, TLO, Depth + 1)) + return true; + } + // If the src element is zero/undef then all the output elements will be - // only demanded elements are guaranteed to be correct. for (unsigned i = 0; i != NumSrcElts; ++i) { @@ -1463,7 +1618,7 @@ bool TargetLowering::SimplifyDemandedVectorElts( EVT SubVT = Sub.getValueType(); unsigned NumSubElts = SubVT.getVectorNumElements(); const APInt& Idx = cast<ConstantSDNode>(Op.getOperand(2))->getAPIntValue(); - if (Idx.uge(NumElts - NumSubElts)) + if (Idx.ugt(NumElts - NumSubElts)) break; unsigned SubIdx = Idx.getZExtValue(); APInt SubElts = DemandedElts.extractBits(NumSubElts, SubIdx); @@ -1481,22 +1636,20 @@ bool TargetLowering::SimplifyDemandedVectorElts( break; } case ISD::EXTRACT_SUBVECTOR: { - if (!isa<ConstantSDNode>(Op.getOperand(1))) - break; SDValue Src = Op.getOperand(0); + ConstantSDNode *SubIdx = dyn_cast<ConstantSDNode>(Op.getOperand(1)); unsigned NumSrcElts = Src.getValueType().getVectorNumElements(); - const APInt& Idx = cast<ConstantSDNode>(Op.getOperand(1))->getAPIntValue(); - if (Idx.uge(NumSrcElts - NumElts)) - break; - // Offset the demanded elts by the subvector index. - uint64_t SubIdx = Idx.getZExtValue(); - APInt SrcElts = DemandedElts.zext(NumSrcElts).shl(SubIdx); - APInt SrcUndef, SrcZero; - if (SimplifyDemandedVectorElts(Src, SrcElts, SrcUndef, SrcZero, TLO, - Depth + 1)) - return true; - KnownUndef = SrcUndef.extractBits(NumElts, SubIdx); - KnownZero = SrcZero.extractBits(NumElts, SubIdx); + if (SubIdx && SubIdx->getAPIntValue().ule(NumSrcElts - NumElts)) { + // Offset the demanded elts by the subvector index. + uint64_t Idx = SubIdx->getZExtValue(); + APInt SrcElts = DemandedElts.zextOrSelf(NumSrcElts).shl(Idx); + APInt SrcUndef, SrcZero; + if (SimplifyDemandedVectorElts(Src, SrcElts, SrcUndef, SrcZero, TLO, + Depth + 1)) + return true; + KnownUndef = SrcUndef.extractBits(NumElts, Idx); + KnownZero = SrcZero.extractBits(NumElts, Idx); + } break; } case ISD::INSERT_VECTOR_ELT: { @@ -1510,9 +1663,10 @@ bool TargetLowering::SimplifyDemandedVectorElts( unsigned Idx = CIdx->getZExtValue(); if (!DemandedElts[Idx]) return TLO.CombineTo(Op, Vec); - DemandedElts.clearBit(Idx); - if (SimplifyDemandedVectorElts(Vec, DemandedElts, KnownUndef, + APInt DemandedVecElts(DemandedElts); + DemandedVecElts.clearBit(Idx); + if (SimplifyDemandedVectorElts(Vec, DemandedVecElts, KnownUndef, KnownZero, TLO, Depth + 1)) return true; @@ -1534,12 +1688,20 @@ bool TargetLowering::SimplifyDemandedVectorElts( break; } case ISD::VSELECT: { - APInt DemandedLHS(DemandedElts); - APInt DemandedRHS(DemandedElts); - - // TODO - add support for constant vselect masks. + // Try to transform the select condition based on the current demanded + // elements. + // TODO: If a condition element is undef, we can choose from one arm of the + // select (and if one arm is undef, then we can propagate that to the + // result). + // TODO - add support for constant vselect masks (see IR version of this). + APInt UnusedUndef, UnusedZero; + if (SimplifyDemandedVectorElts(Op.getOperand(0), DemandedElts, UnusedUndef, + UnusedZero, TLO, Depth + 1)) + return true; // See if we can simplify either vselect operand. + APInt DemandedLHS(DemandedElts); + APInt DemandedRHS(DemandedElts); APInt UndefLHS, ZeroLHS; APInt UndefRHS, ZeroRHS; if (SimplifyDemandedVectorElts(Op.getOperand(1), DemandedLHS, UndefLHS, @@ -1624,8 +1786,35 @@ bool TargetLowering::SimplifyDemandedVectorElts( } break; } + case ISD::SIGN_EXTEND_VECTOR_INREG: + case ISD::ZERO_EXTEND_VECTOR_INREG: { + APInt SrcUndef, SrcZero; + SDValue Src = Op.getOperand(0); + unsigned NumSrcElts = Src.getValueType().getVectorNumElements(); + APInt DemandedSrcElts = DemandedElts.zextOrSelf(NumSrcElts); + if (SimplifyDemandedVectorElts(Src, DemandedSrcElts, SrcUndef, + SrcZero, TLO, Depth + 1)) + return true; + KnownZero = SrcZero.zextOrTrunc(NumElts); + KnownUndef = SrcUndef.zextOrTrunc(NumElts); + + if (Op.getOpcode() == ISD::ZERO_EXTEND_VECTOR_INREG) { + // zext(undef) upper bits are guaranteed to be zero. + if (DemandedElts.isSubsetOf(KnownUndef)) + return TLO.CombineTo(Op, TLO.DAG.getConstant(0, SDLoc(Op), VT)); + KnownUndef.clearAllBits(); + } + break; + } + case ISD::OR: + case ISD::XOR: case ISD::ADD: - case ISD::SUB: { + case ISD::SUB: + case ISD::FADD: + case ISD::FSUB: + case ISD::FMUL: + case ISD::FDIV: + case ISD::FREM: { APInt SrcUndef, SrcZero; if (SimplifyDemandedVectorElts(Op.getOperand(1), DemandedElts, SrcUndef, SrcZero, TLO, Depth + 1)) @@ -1637,21 +1826,58 @@ bool TargetLowering::SimplifyDemandedVectorElts( KnownUndef &= SrcUndef; break; } + case ISD::AND: { + APInt SrcUndef, SrcZero; + if (SimplifyDemandedVectorElts(Op.getOperand(1), DemandedElts, SrcUndef, + SrcZero, TLO, Depth + 1)) + return true; + if (SimplifyDemandedVectorElts(Op.getOperand(0), DemandedElts, KnownUndef, + KnownZero, TLO, Depth + 1)) + return true; + + // If either side has a zero element, then the result element is zero, even + // if the other is an UNDEF. + KnownZero |= SrcZero; + KnownUndef &= SrcUndef; + KnownUndef &= ~KnownZero; + break; + } case ISD::TRUNCATE: + case ISD::SIGN_EXTEND: + case ISD::ZERO_EXTEND: if (SimplifyDemandedVectorElts(Op.getOperand(0), DemandedElts, KnownUndef, KnownZero, TLO, Depth + 1)) return true; + + if (Op.getOpcode() == ISD::ZERO_EXTEND) { + // zext(undef) upper bits are guaranteed to be zero. + if (DemandedElts.isSubsetOf(KnownUndef)) + return TLO.CombineTo(Op, TLO.DAG.getConstant(0, SDLoc(Op), VT)); + KnownUndef.clearAllBits(); + } break; default: { - if (Op.getOpcode() >= ISD::BUILTIN_OP_END) + if (Op.getOpcode() >= ISD::BUILTIN_OP_END) { if (SimplifyDemandedVectorEltsForTargetNode(Op, DemandedElts, KnownUndef, KnownZero, TLO, Depth)) return true; + } else { + KnownBits Known; + APInt DemandedBits = APInt::getAllOnesValue(EltSizeInBits); + if (SimplifyDemandedBits(Op, DemandedBits, DemandedEltMask, Known, TLO, + Depth, AssumeSingleUse)) + return true; + } break; } } - assert((KnownUndef & KnownZero) == 0 && "Elements flagged as undef AND zero"); + + // Constant fold all undef cases. + // TODO: Handle zero cases as well. + if (DemandedElts.isSubsetOf(KnownUndef)) + return TLO.CombineTo(Op, TLO.DAG.getUNDEF(VT)); + return false; } @@ -1711,6 +1937,32 @@ bool TargetLowering::SimplifyDemandedVectorEltsForTargetNode( return false; } +bool TargetLowering::SimplifyDemandedBitsForTargetNode( + SDValue Op, const APInt &DemandedBits, const APInt &DemandedElts, + KnownBits &Known, TargetLoweringOpt &TLO, unsigned Depth) const { + assert((Op.getOpcode() >= ISD::BUILTIN_OP_END || + Op.getOpcode() == ISD::INTRINSIC_WO_CHAIN || + Op.getOpcode() == ISD::INTRINSIC_W_CHAIN || + Op.getOpcode() == ISD::INTRINSIC_VOID) && + "Should use SimplifyDemandedBits if you don't know whether Op" + " is a target node!"); + computeKnownBitsForTargetNode(Op, Known, DemandedElts, TLO.DAG, Depth); + return false; +} + +bool TargetLowering::isKnownNeverNaNForTargetNode(SDValue Op, + const SelectionDAG &DAG, + bool SNaN, + unsigned Depth) const { + assert((Op.getOpcode() >= ISD::BUILTIN_OP_END || + Op.getOpcode() == ISD::INTRINSIC_WO_CHAIN || + Op.getOpcode() == ISD::INTRINSIC_W_CHAIN || + Op.getOpcode() == ISD::INTRINSIC_VOID) && + "Should use isKnownNeverNaN if you don't know whether Op" + " is a target node!"); + return false; +} + // FIXME: Ideally, this would use ISD::isConstantSplatVector(), but that must // work with truncating build vectors and vectors with elements of less than // 8 bits. @@ -1901,10 +2153,24 @@ SDValue TargetLowering::optimizeSetCCOfSignedTruncationCheck( } else return SDValue(); - const APInt &I01 = C01->getAPIntValue(); - // Both of them must be power-of-two, and the constant from setcc is bigger. - if (!(I1.ugt(I01) && I1.isPowerOf2() && I01.isPowerOf2())) - return SDValue(); + APInt I01 = C01->getAPIntValue(); + + auto checkConstants = [&I1, &I01]() -> bool { + // Both of them must be power-of-two, and the constant from setcc is bigger. + return I1.ugt(I01) && I1.isPowerOf2() && I01.isPowerOf2(); + }; + + if (checkConstants()) { + // Great, e.g. got icmp ult i16 (add i16 %x, 128), 256 + } else { + // What if we invert constants? (and the target predicate) + I1.negate(); + I01.negate(); + NewCond = getSetCCInverse(NewCond, /*isInteger=*/true); + if (!checkConstants()) + return SDValue(); + // Great, e.g. got icmp uge i16 (add i16 %x, -128), -256 + } // They are power-of-two, so which bit is set? const unsigned KeptBits = I1.logBase2(); @@ -2141,7 +2407,8 @@ SDValue TargetLowering::SimplifySetCC(EVT VT, SDValue N0, SDValue N1, } if (bestWidth) { EVT newVT = EVT::getIntegerVT(*DAG.getContext(), bestWidth); - if (newVT.isRound()) { + if (newVT.isRound() && + shouldReduceLoadWidth(Lod, ISD::NON_EXTLOAD, newVT)) { EVT PtrType = Lod->getOperand(1).getValueType(); SDValue Ptr = Lod->getBasePtr(); if (bestOffset != 0) @@ -2819,8 +3086,11 @@ SDValue TargetLowering::SimplifySetCC(EVT VT, SDValue N0, SDValue N1, /// Returns true (and the GlobalValue and the offset) if the node is a /// GlobalAddress + offset. -bool TargetLowering::isGAPlusOffset(SDNode *N, const GlobalValue *&GA, +bool TargetLowering::isGAPlusOffset(SDNode *WN, const GlobalValue *&GA, int64_t &Offset) const { + + SDNode *N = unwrapAddress(SDValue(WN, 0)).getNode(); + if (auto *GASD = dyn_cast<GlobalAddressSDNode>(N)) { GA = GASD->getGlobal(); Offset += GASD->getOffset(); @@ -3419,34 +3689,63 @@ void TargetLowering::ComputeConstraintToUse(AsmOperandInfo &OpInfo, /// Given an exact SDIV by a constant, create a multiplication /// with the multiplicative inverse of the constant. -static SDValue BuildExactSDIV(const TargetLowering &TLI, SDValue Op1, APInt d, +static SDValue BuildExactSDIV(const TargetLowering &TLI, SDNode *N, const SDLoc &dl, SelectionDAG &DAG, SmallVectorImpl<SDNode *> &Created) { - assert(d != 0 && "Division by zero!"); + SDValue Op0 = N->getOperand(0); + SDValue Op1 = N->getOperand(1); + EVT VT = N->getValueType(0); + EVT SVT = VT.getScalarType(); + EVT ShVT = TLI.getShiftAmountTy(VT, DAG.getDataLayout()); + EVT ShSVT = ShVT.getScalarType(); + + bool UseSRA = false; + SmallVector<SDValue, 16> Shifts, Factors; + + auto BuildSDIVPattern = [&](ConstantSDNode *C) { + if (C->isNullValue()) + return false; + APInt Divisor = C->getAPIntValue(); + unsigned Shift = Divisor.countTrailingZeros(); + if (Shift) { + Divisor.ashrInPlace(Shift); + UseSRA = true; + } + // Calculate the multiplicative inverse, using Newton's method. + APInt t; + APInt Factor = Divisor; + while ((t = Divisor * Factor) != 1) + Factor *= APInt(Divisor.getBitWidth(), 2) - t; + Shifts.push_back(DAG.getConstant(Shift, dl, ShSVT)); + Factors.push_back(DAG.getConstant(Factor, dl, SVT)); + return true; + }; + + // Collect all magic values from the build vector. + if (!ISD::matchUnaryPredicate(Op1, BuildSDIVPattern)) + return SDValue(); + + SDValue Shift, Factor; + if (VT.isVector()) { + Shift = DAG.getBuildVector(ShVT, dl, Shifts); + Factor = DAG.getBuildVector(VT, dl, Factors); + } else { + Shift = Shifts[0]; + Factor = Factors[0]; + } + + SDValue Res = Op0; // Shift the value upfront if it is even, so the LSB is one. - unsigned ShAmt = d.countTrailingZeros(); - if (ShAmt) { + if (UseSRA) { // TODO: For UDIV use SRL instead of SRA. - SDValue Amt = - DAG.getConstant(ShAmt, dl, TLI.getShiftAmountTy(Op1.getValueType(), - DAG.getDataLayout())); SDNodeFlags Flags; Flags.setExact(true); - Op1 = DAG.getNode(ISD::SRA, dl, Op1.getValueType(), Op1, Amt, Flags); - Created.push_back(Op1.getNode()); - d.ashrInPlace(ShAmt); + Res = DAG.getNode(ISD::SRA, dl, VT, Res, Shift, Flags); + Created.push_back(Res.getNode()); } - // Calculate the multiplicative inverse, using Newton's method. - APInt t, xn = d; - while ((t = d*xn) != 1) - xn *= APInt(d.getBitWidth(), 2) - t; - - SDValue Op2 = DAG.getConstant(xn, dl, Op1.getValueType()); - SDValue Mul = DAG.getNode(ISD::MUL, dl, Op1.getValueType(), Op1, Op2); - Created.push_back(Mul.getNode()); - return Mul; + return DAG.getNode(ISD::MUL, dl, VT, Res, Factor); } SDValue TargetLowering::BuildSDIVPow2(SDNode *N, const APInt &Divisor, @@ -3463,11 +3762,15 @@ SDValue TargetLowering::BuildSDIVPow2(SDNode *N, const APInt &Divisor, /// return a DAG expression to select that will generate the same value by /// multiplying by a magic number. /// Ref: "Hacker's Delight" or "The PowerPC Compiler Writer's Guide". -SDValue TargetLowering::BuildSDIV(SDNode *N, const APInt &Divisor, - SelectionDAG &DAG, bool IsAfterLegalization, +SDValue TargetLowering::BuildSDIV(SDNode *N, SelectionDAG &DAG, + bool IsAfterLegalization, SmallVectorImpl<SDNode *> &Created) const { - EVT VT = N->getValueType(0); SDLoc dl(N); + EVT VT = N->getValueType(0); + EVT SVT = VT.getScalarType(); + EVT ShVT = getShiftAmountTy(VT, DAG.getDataLayout()); + EVT ShSVT = ShVT.getScalarType(); + unsigned EltBits = VT.getScalarSizeInBits(); // Check to see if we can do this. // FIXME: We should be more aggressive here. @@ -3476,50 +3779,90 @@ SDValue TargetLowering::BuildSDIV(SDNode *N, const APInt &Divisor, // If the sdiv has an 'exact' bit we can use a simpler lowering. if (N->getFlags().hasExact()) - return BuildExactSDIV(*this, N->getOperand(0), Divisor, dl, DAG, Created); + return BuildExactSDIV(*this, N, dl, DAG, Created); + + SmallVector<SDValue, 16> MagicFactors, Factors, Shifts, ShiftMasks; + + auto BuildSDIVPattern = [&](ConstantSDNode *C) { + if (C->isNullValue()) + return false; + + const APInt &Divisor = C->getAPIntValue(); + APInt::ms magics = Divisor.magic(); + int NumeratorFactor = 0; + int ShiftMask = -1; + + if (Divisor.isOneValue() || Divisor.isAllOnesValue()) { + // If d is +1/-1, we just multiply the numerator by +1/-1. + NumeratorFactor = Divisor.getSExtValue(); + magics.m = 0; + magics.s = 0; + ShiftMask = 0; + } else if (Divisor.isStrictlyPositive() && magics.m.isNegative()) { + // If d > 0 and m < 0, add the numerator. + NumeratorFactor = 1; + } else if (Divisor.isNegative() && magics.m.isStrictlyPositive()) { + // If d < 0 and m > 0, subtract the numerator. + NumeratorFactor = -1; + } + + MagicFactors.push_back(DAG.getConstant(magics.m, dl, SVT)); + Factors.push_back(DAG.getConstant(NumeratorFactor, dl, SVT)); + Shifts.push_back(DAG.getConstant(magics.s, dl, ShSVT)); + ShiftMasks.push_back(DAG.getConstant(ShiftMask, dl, SVT)); + return true; + }; + + SDValue N0 = N->getOperand(0); + SDValue N1 = N->getOperand(1); + + // Collect the shifts / magic values from each element. + if (!ISD::matchUnaryPredicate(N1, BuildSDIVPattern)) + return SDValue(); - APInt::ms magics = Divisor.magic(); + SDValue MagicFactor, Factor, Shift, ShiftMask; + if (VT.isVector()) { + MagicFactor = DAG.getBuildVector(VT, dl, MagicFactors); + Factor = DAG.getBuildVector(VT, dl, Factors); + Shift = DAG.getBuildVector(ShVT, dl, Shifts); + ShiftMask = DAG.getBuildVector(VT, dl, ShiftMasks); + } else { + MagicFactor = MagicFactors[0]; + Factor = Factors[0]; + Shift = Shifts[0]; + ShiftMask = ShiftMasks[0]; + } - // Multiply the numerator (operand 0) by the magic value - // FIXME: We should support doing a MUL in a wider type + // Multiply the numerator (operand 0) by the magic value. + // FIXME: We should support doing a MUL in a wider type. SDValue Q; - if (IsAfterLegalization ? isOperationLegal(ISD::MULHS, VT) : - isOperationLegalOrCustom(ISD::MULHS, VT)) - Q = DAG.getNode(ISD::MULHS, dl, VT, N->getOperand(0), - DAG.getConstant(magics.m, dl, VT)); - else if (IsAfterLegalization ? isOperationLegal(ISD::SMUL_LOHI, VT) : - isOperationLegalOrCustom(ISD::SMUL_LOHI, VT)) - Q = SDValue(DAG.getNode(ISD::SMUL_LOHI, dl, DAG.getVTList(VT, VT), - N->getOperand(0), - DAG.getConstant(magics.m, dl, VT)).getNode(), 1); - else - return SDValue(); // No mulhs or equvialent + if (IsAfterLegalization ? isOperationLegal(ISD::MULHS, VT) + : isOperationLegalOrCustom(ISD::MULHS, VT)) + Q = DAG.getNode(ISD::MULHS, dl, VT, N0, MagicFactor); + else if (IsAfterLegalization ? isOperationLegal(ISD::SMUL_LOHI, VT) + : isOperationLegalOrCustom(ISD::SMUL_LOHI, VT)) { + SDValue LoHi = + DAG.getNode(ISD::SMUL_LOHI, dl, DAG.getVTList(VT, VT), N0, MagicFactor); + Q = SDValue(LoHi.getNode(), 1); + } else + return SDValue(); // No mulhs or equivalent. + Created.push_back(Q.getNode()); + // (Optionally) Add/subtract the numerator using Factor. + Factor = DAG.getNode(ISD::MUL, dl, VT, N0, Factor); + Created.push_back(Factor.getNode()); + Q = DAG.getNode(ISD::ADD, dl, VT, Q, Factor); Created.push_back(Q.getNode()); - // If d > 0 and m < 0, add the numerator - if (Divisor.isStrictlyPositive() && magics.m.isNegative()) { - Q = DAG.getNode(ISD::ADD, dl, VT, Q, N->getOperand(0)); - Created.push_back(Q.getNode()); - } - // If d < 0 and m > 0, subtract the numerator. - if (Divisor.isNegative() && magics.m.isStrictlyPositive()) { - Q = DAG.getNode(ISD::SUB, dl, VT, Q, N->getOperand(0)); - Created.push_back(Q.getNode()); - } - auto &DL = DAG.getDataLayout(); - // Shift right algebraic if shift value is nonzero - if (magics.s > 0) { - Q = DAG.getNode( - ISD::SRA, dl, VT, Q, - DAG.getConstant(magics.s, dl, getShiftAmountTy(Q.getValueType(), DL))); - Created.push_back(Q.getNode()); - } - // Extract the sign bit and add it to the quotient - SDValue T = - DAG.getNode(ISD::SRL, dl, VT, Q, - DAG.getConstant(VT.getScalarSizeInBits() - 1, dl, - getShiftAmountTy(Q.getValueType(), DL))); + // Shift right algebraic by shift value. + Q = DAG.getNode(ISD::SRA, dl, VT, Q, Shift); + Created.push_back(Q.getNode()); + + // Extract the sign bit, mask it and add it to the quotient. + SDValue SignShift = DAG.getConstant(EltBits - 1, dl, ShVT); + SDValue T = DAG.getNode(ISD::SRL, dl, VT, Q, SignShift); + Created.push_back(T.getNode()); + T = DAG.getNode(ISD::AND, dl, VT, T, ShiftMask); Created.push_back(T.getNode()); return DAG.getNode(ISD::ADD, dl, VT, Q, T); } @@ -3528,72 +3871,133 @@ SDValue TargetLowering::BuildSDIV(SDNode *N, const APInt &Divisor, /// return a DAG expression to select that will generate the same value by /// multiplying by a magic number. /// Ref: "Hacker's Delight" or "The PowerPC Compiler Writer's Guide". -SDValue TargetLowering::BuildUDIV(SDNode *N, const APInt &Divisor, - SelectionDAG &DAG, bool IsAfterLegalization, +SDValue TargetLowering::BuildUDIV(SDNode *N, SelectionDAG &DAG, + bool IsAfterLegalization, SmallVectorImpl<SDNode *> &Created) const { - EVT VT = N->getValueType(0); SDLoc dl(N); - auto &DL = DAG.getDataLayout(); + EVT VT = N->getValueType(0); + EVT SVT = VT.getScalarType(); + EVT ShVT = getShiftAmountTy(VT, DAG.getDataLayout()); + EVT ShSVT = ShVT.getScalarType(); + unsigned EltBits = VT.getScalarSizeInBits(); // Check to see if we can do this. // FIXME: We should be more aggressive here. if (!isTypeLegal(VT)) return SDValue(); - // FIXME: We should use a narrower constant when the upper - // bits are known to be zero. - APInt::mu magics = Divisor.magicu(); + bool UseNPQ = false; + SmallVector<SDValue, 16> PreShifts, PostShifts, MagicFactors, NPQFactors; - SDValue Q = N->getOperand(0); + auto BuildUDIVPattern = [&](ConstantSDNode *C) { + if (C->isNullValue()) + return false; + // FIXME: We should use a narrower constant when the upper + // bits are known to be zero. + APInt Divisor = C->getAPIntValue(); + APInt::mu magics = Divisor.magicu(); + unsigned PreShift = 0, PostShift = 0; + + // If the divisor is even, we can avoid using the expensive fixup by + // shifting the divided value upfront. + if (magics.a != 0 && !Divisor[0]) { + PreShift = Divisor.countTrailingZeros(); + // Get magic number for the shifted divisor. + magics = Divisor.lshr(PreShift).magicu(PreShift); + assert(magics.a == 0 && "Should use cheap fixup now"); + } - // If the divisor is even, we can avoid using the expensive fixup by shifting - // the divided value upfront. - if (magics.a != 0 && !Divisor[0]) { - unsigned Shift = Divisor.countTrailingZeros(); - Q = DAG.getNode( - ISD::SRL, dl, VT, Q, - DAG.getConstant(Shift, dl, getShiftAmountTy(Q.getValueType(), DL))); - Created.push_back(Q.getNode()); + APInt Magic = magics.m; + + unsigned SelNPQ; + if (magics.a == 0 || Divisor.isOneValue()) { + assert(magics.s < Divisor.getBitWidth() && + "We shouldn't generate an undefined shift!"); + PostShift = magics.s; + SelNPQ = false; + } else { + PostShift = magics.s - 1; + SelNPQ = true; + } + + PreShifts.push_back(DAG.getConstant(PreShift, dl, ShSVT)); + MagicFactors.push_back(DAG.getConstant(Magic, dl, SVT)); + NPQFactors.push_back( + DAG.getConstant(SelNPQ ? APInt::getOneBitSet(EltBits, EltBits - 1) + : APInt::getNullValue(EltBits), + dl, SVT)); + PostShifts.push_back(DAG.getConstant(PostShift, dl, ShSVT)); + UseNPQ |= SelNPQ; + return true; + }; - // Get magic number for the shifted divisor. - magics = Divisor.lshr(Shift).magicu(Shift); - assert(magics.a == 0 && "Should use cheap fixup now"); + SDValue N0 = N->getOperand(0); + SDValue N1 = N->getOperand(1); + + // Collect the shifts/magic values from each element. + if (!ISD::matchUnaryPredicate(N1, BuildUDIVPattern)) + return SDValue(); + + SDValue PreShift, PostShift, MagicFactor, NPQFactor; + if (VT.isVector()) { + PreShift = DAG.getBuildVector(ShVT, dl, PreShifts); + MagicFactor = DAG.getBuildVector(VT, dl, MagicFactors); + NPQFactor = DAG.getBuildVector(VT, dl, NPQFactors); + PostShift = DAG.getBuildVector(ShVT, dl, PostShifts); + } else { + PreShift = PreShifts[0]; + MagicFactor = MagicFactors[0]; + PostShift = PostShifts[0]; } - // Multiply the numerator (operand 0) by the magic value - // FIXME: We should support doing a MUL in a wider type - if (IsAfterLegalization ? isOperationLegal(ISD::MULHU, VT) : - isOperationLegalOrCustom(ISD::MULHU, VT)) - Q = DAG.getNode(ISD::MULHU, dl, VT, Q, DAG.getConstant(magics.m, dl, VT)); - else if (IsAfterLegalization ? isOperationLegal(ISD::UMUL_LOHI, VT) : - isOperationLegalOrCustom(ISD::UMUL_LOHI, VT)) - Q = SDValue(DAG.getNode(ISD::UMUL_LOHI, dl, DAG.getVTList(VT, VT), Q, - DAG.getConstant(magics.m, dl, VT)).getNode(), 1); - else - return SDValue(); // No mulhu or equivalent + SDValue Q = N0; + Q = DAG.getNode(ISD::SRL, dl, VT, Q, PreShift); + Created.push_back(Q.getNode()); + + // FIXME: We should support doing a MUL in a wider type. + auto GetMULHU = [&](SDValue X, SDValue Y) { + if (IsAfterLegalization ? isOperationLegal(ISD::MULHU, VT) + : isOperationLegalOrCustom(ISD::MULHU, VT)) + return DAG.getNode(ISD::MULHU, dl, VT, X, Y); + if (IsAfterLegalization ? isOperationLegal(ISD::UMUL_LOHI, VT) + : isOperationLegalOrCustom(ISD::UMUL_LOHI, VT)) { + SDValue LoHi = + DAG.getNode(ISD::UMUL_LOHI, dl, DAG.getVTList(VT, VT), X, Y); + return SDValue(LoHi.getNode(), 1); + } + return SDValue(); // No mulhu or equivalent + }; + + // Multiply the numerator (operand 0) by the magic value. + Q = GetMULHU(Q, MagicFactor); + if (!Q) + return SDValue(); Created.push_back(Q.getNode()); - if (magics.a == 0) { - assert(magics.s < Divisor.getBitWidth() && - "We shouldn't generate an undefined shift!"); - return DAG.getNode( - ISD::SRL, dl, VT, Q, - DAG.getConstant(magics.s, dl, getShiftAmountTy(Q.getValueType(), DL))); - } else { - SDValue NPQ = DAG.getNode(ISD::SUB, dl, VT, N->getOperand(0), Q); - Created.push_back(NPQ.getNode()); - NPQ = DAG.getNode( - ISD::SRL, dl, VT, NPQ, - DAG.getConstant(1, dl, getShiftAmountTy(NPQ.getValueType(), DL))); + if (UseNPQ) { + SDValue NPQ = DAG.getNode(ISD::SUB, dl, VT, N0, Q); Created.push_back(NPQ.getNode()); - NPQ = DAG.getNode(ISD::ADD, dl, VT, NPQ, Q); + + // For vectors we might have a mix of non-NPQ/NPQ paths, so use + // MULHU to act as a SRL-by-1 for NPQ, else multiply by zero. + if (VT.isVector()) + NPQ = GetMULHU(NPQ, NPQFactor); + else + NPQ = DAG.getNode(ISD::SRL, dl, VT, NPQ, DAG.getConstant(1, dl, ShVT)); + Created.push_back(NPQ.getNode()); - return DAG.getNode( - ISD::SRL, dl, VT, NPQ, - DAG.getConstant(magics.s - 1, dl, - getShiftAmountTy(NPQ.getValueType(), DL))); + + Q = DAG.getNode(ISD::ADD, dl, VT, NPQ, Q); + Created.push_back(Q.getNode()); } + + Q = DAG.getNode(ISD::SRL, dl, VT, Q, PostShift); + Created.push_back(Q.getNode()); + + SDValue One = DAG.getConstant(1, dl, VT); + SDValue IsOne = DAG.getSetCC(dl, VT, N1, One, ISD::SETEQ); + return DAG.getSelect(dl, VT, IsOne, N0, Q); } bool TargetLowering:: @@ -3750,8 +4154,17 @@ bool TargetLowering::expandMUL_LOHI(unsigned Opcode, EVT VT, SDLoc dl, if (!MakeMUL_LOHI(LH, RL, Lo, Hi, false)) return false; - Next = DAG.getNode(ISD::ADDC, dl, DAG.getVTList(VT, MVT::Glue), Next, - Merge(Lo, Hi)); + SDValue Zero = DAG.getConstant(0, dl, HiLoVT); + EVT BoolType = getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), VT); + + bool UseGlue = (isOperationLegalOrCustom(ISD::ADDC, VT) && + isOperationLegalOrCustom(ISD::ADDE, VT)); + if (UseGlue) + Next = DAG.getNode(ISD::ADDC, dl, DAG.getVTList(VT, MVT::Glue), Next, + Merge(Lo, Hi)); + else + Next = DAG.getNode(ISD::ADDCARRY, dl, DAG.getVTList(VT, BoolType), Next, + Merge(Lo, Hi), DAG.getConstant(0, dl, BoolType)); SDValue Carry = Next.getValue(1); Result.push_back(DAG.getNode(ISD::TRUNCATE, dl, HiLoVT, Next)); @@ -3760,9 +4173,13 @@ bool TargetLowering::expandMUL_LOHI(unsigned Opcode, EVT VT, SDLoc dl, if (!MakeMUL_LOHI(LH, RH, Lo, Hi, Opcode == ISD::SMUL_LOHI)) return false; - SDValue Zero = DAG.getConstant(0, dl, HiLoVT); - Hi = DAG.getNode(ISD::ADDE, dl, DAG.getVTList(HiLoVT, MVT::Glue), Hi, Zero, - Carry); + if (UseGlue) + Hi = DAG.getNode(ISD::ADDE, dl, DAG.getVTList(HiLoVT, MVT::Glue), Hi, Zero, + Carry); + else + Hi = DAG.getNode(ISD::ADDCARRY, dl, DAG.getVTList(HiLoVT, BoolType), Hi, + Zero, Carry); + Next = DAG.getNode(ISD::ADD, dl, VT, Next, Merge(Lo, Hi)); if (Opcode == ISD::SMUL_LOHI) { @@ -3797,66 +4214,525 @@ bool TargetLowering::expandMUL(SDNode *N, SDValue &Lo, SDValue &Hi, EVT HiLoVT, return Ok; } +bool TargetLowering::expandFunnelShift(SDNode *Node, SDValue &Result, + SelectionDAG &DAG) const { + EVT VT = Node->getValueType(0); + + if (VT.isVector() && (!isOperationLegalOrCustom(ISD::SHL, VT) || + !isOperationLegalOrCustom(ISD::SRL, VT) || + !isOperationLegalOrCustom(ISD::SUB, VT) || + !isOperationLegalOrCustomOrPromote(ISD::OR, VT))) + return false; + + // fshl: (X << (Z % BW)) | (Y >> (BW - (Z % BW))) + // fshr: (X << (BW - (Z % BW))) | (Y >> (Z % BW)) + SDValue X = Node->getOperand(0); + SDValue Y = Node->getOperand(1); + SDValue Z = Node->getOperand(2); + + unsigned EltSizeInBits = VT.getScalarSizeInBits(); + bool IsFSHL = Node->getOpcode() == ISD::FSHL; + SDLoc DL(SDValue(Node, 0)); + + EVT ShVT = Z.getValueType(); + SDValue BitWidthC = DAG.getConstant(EltSizeInBits, DL, ShVT); + SDValue Zero = DAG.getConstant(0, DL, ShVT); + + SDValue ShAmt; + if (isPowerOf2_32(EltSizeInBits)) { + SDValue Mask = DAG.getConstant(EltSizeInBits - 1, DL, ShVT); + ShAmt = DAG.getNode(ISD::AND, DL, ShVT, Z, Mask); + } else { + ShAmt = DAG.getNode(ISD::UREM, DL, ShVT, Z, BitWidthC); + } + + SDValue InvShAmt = DAG.getNode(ISD::SUB, DL, ShVT, BitWidthC, ShAmt); + SDValue ShX = DAG.getNode(ISD::SHL, DL, VT, X, IsFSHL ? ShAmt : InvShAmt); + SDValue ShY = DAG.getNode(ISD::SRL, DL, VT, Y, IsFSHL ? InvShAmt : ShAmt); + SDValue Or = DAG.getNode(ISD::OR, DL, VT, ShX, ShY); + + // If (Z % BW == 0), then the opposite direction shift is shift-by-bitwidth, + // and that is undefined. We must compare and select to avoid UB. + EVT CCVT = getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), ShVT); + + // For fshl, 0-shift returns the 1st arg (X). + // For fshr, 0-shift returns the 2nd arg (Y). + SDValue IsZeroShift = DAG.getSetCC(DL, CCVT, ShAmt, Zero, ISD::SETEQ); + Result = DAG.getSelect(DL, VT, IsZeroShift, IsFSHL ? X : Y, Or); + return true; +} + +// TODO: Merge with expandFunnelShift. +bool TargetLowering::expandROT(SDNode *Node, SDValue &Result, + SelectionDAG &DAG) const { + EVT VT = Node->getValueType(0); + unsigned EltSizeInBits = VT.getScalarSizeInBits(); + bool IsLeft = Node->getOpcode() == ISD::ROTL; + SDValue Op0 = Node->getOperand(0); + SDValue Op1 = Node->getOperand(1); + SDLoc DL(SDValue(Node, 0)); + + EVT ShVT = Op1.getValueType(); + SDValue BitWidthC = DAG.getConstant(EltSizeInBits, DL, ShVT); + + // If a rotate in the other direction is legal, use it. + unsigned RevRot = IsLeft ? ISD::ROTR : ISD::ROTL; + if (isOperationLegal(RevRot, VT)) { + SDValue Sub = DAG.getNode(ISD::SUB, DL, ShVT, BitWidthC, Op1); + Result = DAG.getNode(RevRot, DL, VT, Op0, Sub); + return true; + } + + if (VT.isVector() && (!isOperationLegalOrCustom(ISD::SHL, VT) || + !isOperationLegalOrCustom(ISD::SRL, VT) || + !isOperationLegalOrCustom(ISD::SUB, VT) || + !isOperationLegalOrCustomOrPromote(ISD::OR, VT) || + !isOperationLegalOrCustomOrPromote(ISD::AND, VT))) + return false; + + // Otherwise, + // (rotl x, c) -> (or (shl x, (and c, w-1)), (srl x, (and w-c, w-1))) + // (rotr x, c) -> (or (srl x, (and c, w-1)), (shl x, (and w-c, w-1))) + // + assert(isPowerOf2_32(EltSizeInBits) && EltSizeInBits > 1 && + "Expecting the type bitwidth to be a power of 2"); + unsigned ShOpc = IsLeft ? ISD::SHL : ISD::SRL; + unsigned HsOpc = IsLeft ? ISD::SRL : ISD::SHL; + SDValue BitWidthMinusOneC = DAG.getConstant(EltSizeInBits - 1, DL, ShVT); + SDValue NegOp1 = DAG.getNode(ISD::SUB, DL, ShVT, BitWidthC, Op1); + SDValue And0 = DAG.getNode(ISD::AND, DL, ShVT, Op1, BitWidthMinusOneC); + SDValue And1 = DAG.getNode(ISD::AND, DL, ShVT, NegOp1, BitWidthMinusOneC); + Result = DAG.getNode(ISD::OR, DL, VT, DAG.getNode(ShOpc, DL, VT, Op0, And0), + DAG.getNode(HsOpc, DL, VT, Op0, And1)); + return true; +} + bool TargetLowering::expandFP_TO_SINT(SDNode *Node, SDValue &Result, SelectionDAG &DAG) const { - EVT VT = Node->getOperand(0).getValueType(); - EVT NVT = Node->getValueType(0); + SDValue Src = Node->getOperand(0); + EVT SrcVT = Src.getValueType(); + EVT DstVT = Node->getValueType(0); SDLoc dl(SDValue(Node, 0)); // FIXME: Only f32 to i64 conversions are supported. - if (VT != MVT::f32 || NVT != MVT::i64) + if (SrcVT != MVT::f32 || DstVT != MVT::i64) return false; // Expand f32 -> i64 conversion // This algorithm comes from compiler-rt's implementation of fixsfdi: // https://github.com/llvm-mirror/compiler-rt/blob/master/lib/builtins/fixsfdi.c - EVT IntVT = EVT::getIntegerVT(*DAG.getContext(), - VT.getSizeInBits()); + unsigned SrcEltBits = SrcVT.getScalarSizeInBits(); + EVT IntVT = SrcVT.changeTypeToInteger(); + EVT IntShVT = getShiftAmountTy(IntVT, DAG.getDataLayout()); + SDValue ExponentMask = DAG.getConstant(0x7F800000, dl, IntVT); SDValue ExponentLoBit = DAG.getConstant(23, dl, IntVT); SDValue Bias = DAG.getConstant(127, dl, IntVT); - SDValue SignMask = DAG.getConstant(APInt::getSignMask(VT.getSizeInBits()), dl, - IntVT); - SDValue SignLowBit = DAG.getConstant(VT.getSizeInBits() - 1, dl, IntVT); + SDValue SignMask = DAG.getConstant(APInt::getSignMask(SrcEltBits), dl, IntVT); + SDValue SignLowBit = DAG.getConstant(SrcEltBits - 1, dl, IntVT); SDValue MantissaMask = DAG.getConstant(0x007FFFFF, dl, IntVT); - SDValue Bits = DAG.getNode(ISD::BITCAST, dl, IntVT, Node->getOperand(0)); + SDValue Bits = DAG.getNode(ISD::BITCAST, dl, IntVT, Src); - auto &DL = DAG.getDataLayout(); SDValue ExponentBits = DAG.getNode( ISD::SRL, dl, IntVT, DAG.getNode(ISD::AND, dl, IntVT, Bits, ExponentMask), - DAG.getZExtOrTrunc(ExponentLoBit, dl, getShiftAmountTy(IntVT, DL))); + DAG.getZExtOrTrunc(ExponentLoBit, dl, IntShVT)); SDValue Exponent = DAG.getNode(ISD::SUB, dl, IntVT, ExponentBits, Bias); - SDValue Sign = DAG.getNode( - ISD::SRA, dl, IntVT, DAG.getNode(ISD::AND, dl, IntVT, Bits, SignMask), - DAG.getZExtOrTrunc(SignLowBit, dl, getShiftAmountTy(IntVT, DL))); - Sign = DAG.getSExtOrTrunc(Sign, dl, NVT); + SDValue Sign = DAG.getNode(ISD::SRA, dl, IntVT, + DAG.getNode(ISD::AND, dl, IntVT, Bits, SignMask), + DAG.getZExtOrTrunc(SignLowBit, dl, IntShVT)); + Sign = DAG.getSExtOrTrunc(Sign, dl, DstVT); SDValue R = DAG.getNode(ISD::OR, dl, IntVT, - DAG.getNode(ISD::AND, dl, IntVT, Bits, MantissaMask), - DAG.getConstant(0x00800000, dl, IntVT)); + DAG.getNode(ISD::AND, dl, IntVT, Bits, MantissaMask), + DAG.getConstant(0x00800000, dl, IntVT)); - R = DAG.getZExtOrTrunc(R, dl, NVT); + R = DAG.getZExtOrTrunc(R, dl, DstVT); R = DAG.getSelectCC( dl, Exponent, ExponentLoBit, - DAG.getNode(ISD::SHL, dl, NVT, R, + DAG.getNode(ISD::SHL, dl, DstVT, R, DAG.getZExtOrTrunc( DAG.getNode(ISD::SUB, dl, IntVT, Exponent, ExponentLoBit), - dl, getShiftAmountTy(IntVT, DL))), - DAG.getNode(ISD::SRL, dl, NVT, R, + dl, IntShVT)), + DAG.getNode(ISD::SRL, dl, DstVT, R, DAG.getZExtOrTrunc( DAG.getNode(ISD::SUB, dl, IntVT, ExponentLoBit, Exponent), - dl, getShiftAmountTy(IntVT, DL))), + dl, IntShVT)), ISD::SETGT); - SDValue Ret = DAG.getNode(ISD::SUB, dl, NVT, - DAG.getNode(ISD::XOR, dl, NVT, R, Sign), - Sign); + SDValue Ret = DAG.getNode(ISD::SUB, dl, DstVT, + DAG.getNode(ISD::XOR, dl, DstVT, R, Sign), Sign); Result = DAG.getSelectCC(dl, Exponent, DAG.getConstant(0, dl, IntVT), - DAG.getConstant(0, dl, NVT), Ret, ISD::SETLT); + DAG.getConstant(0, dl, DstVT), Ret, ISD::SETLT); + return true; +} + +bool TargetLowering::expandFP_TO_UINT(SDNode *Node, SDValue &Result, + SelectionDAG &DAG) const { + SDLoc dl(SDValue(Node, 0)); + SDValue Src = Node->getOperand(0); + + EVT SrcVT = Src.getValueType(); + EVT DstVT = Node->getValueType(0); + EVT SetCCVT = + getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), SrcVT); + + // Only expand vector types if we have the appropriate vector bit operations. + if (DstVT.isVector() && (!isOperationLegalOrCustom(ISD::FP_TO_SINT, DstVT) || + !isOperationLegalOrCustomOrPromote(ISD::XOR, SrcVT))) + return false; + + // If the maximum float value is smaller then the signed integer range, + // the destination signmask can't be represented by the float, so we can + // just use FP_TO_SINT directly. + const fltSemantics &APFSem = DAG.EVTToAPFloatSemantics(SrcVT); + APFloat APF(APFSem, APInt::getNullValue(SrcVT.getScalarSizeInBits())); + APInt SignMask = APInt::getSignMask(DstVT.getScalarSizeInBits()); + if (APFloat::opOverflow & + APF.convertFromAPInt(SignMask, false, APFloat::rmNearestTiesToEven)) { + Result = DAG.getNode(ISD::FP_TO_SINT, dl, DstVT, Src); + return true; + } + + SDValue Cst = DAG.getConstantFP(APF, dl, SrcVT); + SDValue Sel = DAG.getSetCC(dl, SetCCVT, Src, Cst, ISD::SETLT); + + bool Strict = shouldUseStrictFP_TO_INT(SrcVT, DstVT, /*IsSigned*/ false); + if (Strict) { + // Expand based on maximum range of FP_TO_SINT, if the value exceeds the + // signmask then offset (the result of which should be fully representable). + // Sel = Src < 0x8000000000000000 + // Val = select Sel, Src, Src - 0x8000000000000000 + // Ofs = select Sel, 0, 0x8000000000000000 + // Result = fp_to_sint(Val) ^ Ofs + + // TODO: Should any fast-math-flags be set for the FSUB? + SDValue Val = DAG.getSelect(dl, SrcVT, Sel, Src, + DAG.getNode(ISD::FSUB, dl, SrcVT, Src, Cst)); + SDValue Ofs = DAG.getSelect(dl, DstVT, Sel, DAG.getConstant(0, dl, DstVT), + DAG.getConstant(SignMask, dl, DstVT)); + Result = DAG.getNode(ISD::XOR, dl, DstVT, + DAG.getNode(ISD::FP_TO_SINT, dl, DstVT, Val), Ofs); + } else { + // Expand based on maximum range of FP_TO_SINT: + // True = fp_to_sint(Src) + // False = 0x8000000000000000 + fp_to_sint(Src - 0x8000000000000000) + // Result = select (Src < 0x8000000000000000), True, False + + SDValue True = DAG.getNode(ISD::FP_TO_SINT, dl, DstVT, Src); + // TODO: Should any fast-math-flags be set for the FSUB? + SDValue False = DAG.getNode(ISD::FP_TO_SINT, dl, DstVT, + DAG.getNode(ISD::FSUB, dl, SrcVT, Src, Cst)); + False = DAG.getNode(ISD::XOR, dl, DstVT, False, + DAG.getConstant(SignMask, dl, DstVT)); + Result = DAG.getSelect(dl, DstVT, Sel, True, False); + } + return true; +} + +bool TargetLowering::expandUINT_TO_FP(SDNode *Node, SDValue &Result, + SelectionDAG &DAG) const { + SDValue Src = Node->getOperand(0); + EVT SrcVT = Src.getValueType(); + EVT DstVT = Node->getValueType(0); + + if (SrcVT.getScalarType() != MVT::i64) + return false; + + SDLoc dl(SDValue(Node, 0)); + EVT ShiftVT = getShiftAmountTy(SrcVT, DAG.getDataLayout()); + + if (DstVT.getScalarType() == MVT::f32) { + // Only expand vector types if we have the appropriate vector bit + // operations. + if (SrcVT.isVector() && + (!isOperationLegalOrCustom(ISD::SRL, SrcVT) || + !isOperationLegalOrCustom(ISD::FADD, DstVT) || + !isOperationLegalOrCustom(ISD::SINT_TO_FP, SrcVT) || + !isOperationLegalOrCustomOrPromote(ISD::OR, SrcVT) || + !isOperationLegalOrCustomOrPromote(ISD::AND, SrcVT))) + return false; + + // For unsigned conversions, convert them to signed conversions using the + // algorithm from the x86_64 __floatundidf in compiler_rt. + SDValue Fast = DAG.getNode(ISD::SINT_TO_FP, dl, DstVT, Src); + + SDValue ShiftConst = DAG.getConstant(1, dl, ShiftVT); + SDValue Shr = DAG.getNode(ISD::SRL, dl, SrcVT, Src, ShiftConst); + SDValue AndConst = DAG.getConstant(1, dl, SrcVT); + SDValue And = DAG.getNode(ISD::AND, dl, SrcVT, Src, AndConst); + SDValue Or = DAG.getNode(ISD::OR, dl, SrcVT, And, Shr); + + SDValue SignCvt = DAG.getNode(ISD::SINT_TO_FP, dl, DstVT, Or); + SDValue Slow = DAG.getNode(ISD::FADD, dl, DstVT, SignCvt, SignCvt); + + // TODO: This really should be implemented using a branch rather than a + // select. We happen to get lucky and machinesink does the right + // thing most of the time. This would be a good candidate for a + // pseudo-op, or, even better, for whole-function isel. + EVT SetCCVT = + getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), SrcVT); + + SDValue SignBitTest = DAG.getSetCC( + dl, SetCCVT, Src, DAG.getConstant(0, dl, SrcVT), ISD::SETLT); + Result = DAG.getSelect(dl, DstVT, SignBitTest, Slow, Fast); + return true; + } + + if (DstVT.getScalarType() == MVT::f64) { + // Only expand vector types if we have the appropriate vector bit + // operations. + if (SrcVT.isVector() && + (!isOperationLegalOrCustom(ISD::SRL, SrcVT) || + !isOperationLegalOrCustom(ISD::FADD, DstVT) || + !isOperationLegalOrCustom(ISD::FSUB, DstVT) || + !isOperationLegalOrCustomOrPromote(ISD::OR, SrcVT) || + !isOperationLegalOrCustomOrPromote(ISD::AND, SrcVT))) + return false; + + // Implementation of unsigned i64 to f64 following the algorithm in + // __floatundidf in compiler_rt. This implementation has the advantage + // of performing rounding correctly, both in the default rounding mode + // and in all alternate rounding modes. + SDValue TwoP52 = DAG.getConstant(UINT64_C(0x4330000000000000), dl, SrcVT); + SDValue TwoP84PlusTwoP52 = DAG.getConstantFP( + BitsToDouble(UINT64_C(0x4530000000100000)), dl, DstVT); + SDValue TwoP84 = DAG.getConstant(UINT64_C(0x4530000000000000), dl, SrcVT); + SDValue LoMask = DAG.getConstant(UINT64_C(0x00000000FFFFFFFF), dl, SrcVT); + SDValue HiShift = DAG.getConstant(32, dl, ShiftVT); + + SDValue Lo = DAG.getNode(ISD::AND, dl, SrcVT, Src, LoMask); + SDValue Hi = DAG.getNode(ISD::SRL, dl, SrcVT, Src, HiShift); + SDValue LoOr = DAG.getNode(ISD::OR, dl, SrcVT, Lo, TwoP52); + SDValue HiOr = DAG.getNode(ISD::OR, dl, SrcVT, Hi, TwoP84); + SDValue LoFlt = DAG.getBitcast(DstVT, LoOr); + SDValue HiFlt = DAG.getBitcast(DstVT, HiOr); + SDValue HiSub = DAG.getNode(ISD::FSUB, dl, DstVT, HiFlt, TwoP84PlusTwoP52); + Result = DAG.getNode(ISD::FADD, dl, DstVT, LoFlt, HiSub); + return true; + } + + return false; +} + +SDValue TargetLowering::expandFMINNUM_FMAXNUM(SDNode *Node, + SelectionDAG &DAG) const { + SDLoc dl(Node); + unsigned NewOp = Node->getOpcode() == ISD::FMINNUM ? + ISD::FMINNUM_IEEE : ISD::FMAXNUM_IEEE; + EVT VT = Node->getValueType(0); + if (isOperationLegalOrCustom(NewOp, VT)) { + SDValue Quiet0 = Node->getOperand(0); + SDValue Quiet1 = Node->getOperand(1); + + if (!Node->getFlags().hasNoNaNs()) { + // Insert canonicalizes if it's possible we need to quiet to get correct + // sNaN behavior. + if (!DAG.isKnownNeverSNaN(Quiet0)) { + Quiet0 = DAG.getNode(ISD::FCANONICALIZE, dl, VT, Quiet0, + Node->getFlags()); + } + if (!DAG.isKnownNeverSNaN(Quiet1)) { + Quiet1 = DAG.getNode(ISD::FCANONICALIZE, dl, VT, Quiet1, + Node->getFlags()); + } + } + + return DAG.getNode(NewOp, dl, VT, Quiet0, Quiet1, Node->getFlags()); + } + + return SDValue(); +} + +bool TargetLowering::expandCTPOP(SDNode *Node, SDValue &Result, + SelectionDAG &DAG) const { + SDLoc dl(Node); + EVT VT = Node->getValueType(0); + EVT ShVT = getShiftAmountTy(VT, DAG.getDataLayout()); + SDValue Op = Node->getOperand(0); + unsigned Len = VT.getScalarSizeInBits(); + assert(VT.isInteger() && "CTPOP not implemented for this type."); + + // TODO: Add support for irregular type lengths. + if (!(Len <= 128 && Len % 8 == 0)) + return false; + + // Only expand vector types if we have the appropriate vector bit operations. + if (VT.isVector() && (!isOperationLegalOrCustom(ISD::ADD, VT) || + !isOperationLegalOrCustom(ISD::SUB, VT) || + !isOperationLegalOrCustom(ISD::SRL, VT) || + (Len != 8 && !isOperationLegalOrCustom(ISD::MUL, VT)) || + !isOperationLegalOrCustomOrPromote(ISD::AND, VT))) + return false; + + // This is the "best" algorithm from + // http://graphics.stanford.edu/~seander/bithacks.html#CountBitsSetParallel + SDValue Mask55 = + DAG.getConstant(APInt::getSplat(Len, APInt(8, 0x55)), dl, VT); + SDValue Mask33 = + DAG.getConstant(APInt::getSplat(Len, APInt(8, 0x33)), dl, VT); + SDValue Mask0F = + DAG.getConstant(APInt::getSplat(Len, APInt(8, 0x0F)), dl, VT); + SDValue Mask01 = + DAG.getConstant(APInt::getSplat(Len, APInt(8, 0x01)), dl, VT); + + // v = v - ((v >> 1) & 0x55555555...) + Op = DAG.getNode(ISD::SUB, dl, VT, Op, + DAG.getNode(ISD::AND, dl, VT, + DAG.getNode(ISD::SRL, dl, VT, Op, + DAG.getConstant(1, dl, ShVT)), + Mask55)); + // v = (v & 0x33333333...) + ((v >> 2) & 0x33333333...) + Op = DAG.getNode(ISD::ADD, dl, VT, DAG.getNode(ISD::AND, dl, VT, Op, Mask33), + DAG.getNode(ISD::AND, dl, VT, + DAG.getNode(ISD::SRL, dl, VT, Op, + DAG.getConstant(2, dl, ShVT)), + Mask33)); + // v = (v + (v >> 4)) & 0x0F0F0F0F... + Op = DAG.getNode(ISD::AND, dl, VT, + DAG.getNode(ISD::ADD, dl, VT, Op, + DAG.getNode(ISD::SRL, dl, VT, Op, + DAG.getConstant(4, dl, ShVT))), + Mask0F); + // v = (v * 0x01010101...) >> (Len - 8) + if (Len > 8) + Op = + DAG.getNode(ISD::SRL, dl, VT, DAG.getNode(ISD::MUL, dl, VT, Op, Mask01), + DAG.getConstant(Len - 8, dl, ShVT)); + + Result = Op; + return true; +} + +bool TargetLowering::expandCTLZ(SDNode *Node, SDValue &Result, + SelectionDAG &DAG) const { + SDLoc dl(Node); + EVT VT = Node->getValueType(0); + EVT ShVT = getShiftAmountTy(VT, DAG.getDataLayout()); + SDValue Op = Node->getOperand(0); + unsigned NumBitsPerElt = VT.getScalarSizeInBits(); + + // If the non-ZERO_UNDEF version is supported we can use that instead. + if (Node->getOpcode() == ISD::CTLZ_ZERO_UNDEF && + isOperationLegalOrCustom(ISD::CTLZ, VT)) { + Result = DAG.getNode(ISD::CTLZ, dl, VT, Op); + return true; + } + + // If the ZERO_UNDEF version is supported use that and handle the zero case. + if (isOperationLegalOrCustom(ISD::CTLZ_ZERO_UNDEF, VT)) { + EVT SetCCVT = + getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), VT); + SDValue CTLZ = DAG.getNode(ISD::CTLZ_ZERO_UNDEF, dl, VT, Op); + SDValue Zero = DAG.getConstant(0, dl, VT); + SDValue SrcIsZero = DAG.getSetCC(dl, SetCCVT, Op, Zero, ISD::SETEQ); + Result = DAG.getNode(ISD::SELECT, dl, VT, SrcIsZero, + DAG.getConstant(NumBitsPerElt, dl, VT), CTLZ); + return true; + } + + // Only expand vector types if we have the appropriate vector bit operations. + if (VT.isVector() && (!isPowerOf2_32(NumBitsPerElt) || + !isOperationLegalOrCustom(ISD::CTPOP, VT) || + !isOperationLegalOrCustom(ISD::SRL, VT) || + !isOperationLegalOrCustomOrPromote(ISD::OR, VT))) + return false; + + // for now, we do this: + // x = x | (x >> 1); + // x = x | (x >> 2); + // ... + // x = x | (x >>16); + // x = x | (x >>32); // for 64-bit input + // return popcount(~x); + // + // Ref: "Hacker's Delight" by Henry Warren + for (unsigned i = 0; (1U << i) <= (NumBitsPerElt / 2); ++i) { + SDValue Tmp = DAG.getConstant(1ULL << i, dl, ShVT); + Op = DAG.getNode(ISD::OR, dl, VT, Op, + DAG.getNode(ISD::SRL, dl, VT, Op, Tmp)); + } + Op = DAG.getNOT(dl, Op, VT); + Result = DAG.getNode(ISD::CTPOP, dl, VT, Op); + return true; +} + +bool TargetLowering::expandCTTZ(SDNode *Node, SDValue &Result, + SelectionDAG &DAG) const { + SDLoc dl(Node); + EVT VT = Node->getValueType(0); + SDValue Op = Node->getOperand(0); + unsigned NumBitsPerElt = VT.getScalarSizeInBits(); + + // If the non-ZERO_UNDEF version is supported we can use that instead. + if (Node->getOpcode() == ISD::CTTZ_ZERO_UNDEF && + isOperationLegalOrCustom(ISD::CTTZ, VT)) { + Result = DAG.getNode(ISD::CTTZ, dl, VT, Op); + return true; + } + + // If the ZERO_UNDEF version is supported use that and handle the zero case. + if (isOperationLegalOrCustom(ISD::CTTZ_ZERO_UNDEF, VT)) { + EVT SetCCVT = + getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), VT); + SDValue CTTZ = DAG.getNode(ISD::CTTZ_ZERO_UNDEF, dl, VT, Op); + SDValue Zero = DAG.getConstant(0, dl, VT); + SDValue SrcIsZero = DAG.getSetCC(dl, SetCCVT, Op, Zero, ISD::SETEQ); + Result = DAG.getNode(ISD::SELECT, dl, VT, SrcIsZero, + DAG.getConstant(NumBitsPerElt, dl, VT), CTTZ); + return true; + } + + // Only expand vector types if we have the appropriate vector bit operations. + if (VT.isVector() && (!isPowerOf2_32(NumBitsPerElt) || + (!isOperationLegalOrCustom(ISD::CTPOP, VT) && + !isOperationLegalOrCustom(ISD::CTLZ, VT)) || + !isOperationLegalOrCustom(ISD::SUB, VT) || + !isOperationLegalOrCustomOrPromote(ISD::AND, VT) || + !isOperationLegalOrCustomOrPromote(ISD::XOR, VT))) + return false; + + // for now, we use: { return popcount(~x & (x - 1)); } + // unless the target has ctlz but not ctpop, in which case we use: + // { return 32 - nlz(~x & (x-1)); } + // Ref: "Hacker's Delight" by Henry Warren + SDValue Tmp = DAG.getNode( + ISD::AND, dl, VT, DAG.getNOT(dl, Op, VT), + DAG.getNode(ISD::SUB, dl, VT, Op, DAG.getConstant(1, dl, VT))); + + // If ISD::CTLZ is legal and CTPOP isn't, then do that instead. + if (isOperationLegal(ISD::CTLZ, VT) && !isOperationLegal(ISD::CTPOP, VT)) { + Result = + DAG.getNode(ISD::SUB, dl, VT, DAG.getConstant(NumBitsPerElt, dl, VT), + DAG.getNode(ISD::CTLZ, dl, VT, Tmp)); + return true; + } + + Result = DAG.getNode(ISD::CTPOP, dl, VT, Tmp); + return true; +} + +bool TargetLowering::expandABS(SDNode *N, SDValue &Result, + SelectionDAG &DAG) const { + SDLoc dl(N); + EVT VT = N->getValueType(0); + EVT ShVT = getShiftAmountTy(VT, DAG.getDataLayout()); + SDValue Op = N->getOperand(0); + + // Only expand vector types if we have the appropriate vector operations. + if (VT.isVector() && (!isOperationLegalOrCustom(ISD::SRA, VT) || + !isOperationLegalOrCustom(ISD::ADD, VT) || + !isOperationLegalOrCustomOrPromote(ISD::XOR, VT))) + return false; + + SDValue Shift = + DAG.getNode(ISD::SRA, dl, VT, Op, + DAG.getConstant(VT.getScalarSizeInBits() - 1, dl, ShVT)); + SDValue Add = DAG.getNode(ISD::ADD, dl, VT, Op, Shift); + Result = DAG.getNode(ISD::XOR, dl, VT, Add, Shift); return true; } @@ -3876,8 +4752,6 @@ SDValue TargetLowering::scalarizeVectorLoad(LoadSDNode *LD, unsigned Stride = SrcEltVT.getSizeInBits() / 8; assert(SrcEltVT.isByteSized()); - EVT PtrVT = BasePTR.getValueType(); - SmallVector<SDValue, 8> Vals; SmallVector<SDValue, 8> LoadChains; @@ -3888,8 +4762,7 @@ SDValue TargetLowering::scalarizeVectorLoad(LoadSDNode *LD, SrcEltVT, MinAlign(LD->getAlignment(), Idx * Stride), LD->getMemOperand()->getFlags(), LD->getAAInfo()); - BasePTR = DAG.getNode(ISD::ADD, SL, PtrVT, BasePTR, - DAG.getConstant(Stride, SL, PtrVT)); + BasePTR = DAG.getObjectPtrOffset(SL, BasePTR, Stride); Vals.push_back(ScalarLoad.getValue(0)); LoadChains.push_back(ScalarLoad.getValue(1)); @@ -3989,7 +4862,8 @@ TargetLowering::expandUnalignedLoad(LoadSDNode *LD, SelectionDAG &DAG) const { if (VT.isFloatingPoint() || VT.isVector()) { EVT intVT = EVT::getIntegerVT(*DAG.getContext(), LoadedVT.getSizeInBits()); if (isTypeLegal(intVT) && isTypeLegal(LoadedVT)) { - if (!isOperationLegalOrCustom(ISD::LOAD, intVT)) { + if (!isOperationLegalOrCustom(ISD::LOAD, intVT) && + LoadedVT.isVector()) { // Scalarize the load and let the individual components be handled. SDValue Scalarized = scalarizeVectorLoad(LD, DAG); if (Scalarized->getOpcode() == ISD::MERGE_VALUES) @@ -4139,13 +5013,14 @@ SDValue TargetLowering::expandUnalignedStore(StoreSDNode *ST, EVT VT = Val.getValueType(); int Alignment = ST->getAlignment(); auto &MF = DAG.getMachineFunction(); + EVT MemVT = ST->getMemoryVT(); SDLoc dl(ST); - if (ST->getMemoryVT().isFloatingPoint() || - ST->getMemoryVT().isVector()) { + if (MemVT.isFloatingPoint() || MemVT.isVector()) { EVT intVT = EVT::getIntegerVT(*DAG.getContext(), VT.getSizeInBits()); if (isTypeLegal(intVT)) { - if (!isOperationLegalOrCustom(ISD::STORE, intVT)) { + if (!isOperationLegalOrCustom(ISD::STORE, intVT) && + MemVT.isVector()) { // Scalarize the store and let the individual components be handled. SDValue Result = scalarizeVectorStore(ST, DAG); @@ -4399,3 +5274,134 @@ SDValue TargetLowering::lowerCmpEqZeroToCtlzSrl(SDValue Op, } return SDValue(); } + +SDValue TargetLowering::expandAddSubSat(SDNode *Node, SelectionDAG &DAG) const { + unsigned Opcode = Node->getOpcode(); + SDValue LHS = Node->getOperand(0); + SDValue RHS = Node->getOperand(1); + EVT VT = LHS.getValueType(); + SDLoc dl(Node); + + // usub.sat(a, b) -> umax(a, b) - b + if (Opcode == ISD::USUBSAT && isOperationLegalOrCustom(ISD::UMAX, VT)) { + SDValue Max = DAG.getNode(ISD::UMAX, dl, VT, LHS, RHS); + return DAG.getNode(ISD::SUB, dl, VT, Max, RHS); + } + + if (VT.isVector()) { + // TODO: Consider not scalarizing here. + return SDValue(); + } + + unsigned OverflowOp; + switch (Opcode) { + case ISD::SADDSAT: + OverflowOp = ISD::SADDO; + break; + case ISD::UADDSAT: + OverflowOp = ISD::UADDO; + break; + case ISD::SSUBSAT: + OverflowOp = ISD::SSUBO; + break; + case ISD::USUBSAT: + OverflowOp = ISD::USUBO; + break; + default: + llvm_unreachable("Expected method to receive signed or unsigned saturation " + "addition or subtraction node."); + } + + assert(LHS.getValueType().isScalarInteger() && + "Expected operands to be integers. Vector of int arguments should " + "already be unrolled."); + assert(RHS.getValueType().isScalarInteger() && + "Expected operands to be integers. Vector of int arguments should " + "already be unrolled."); + assert(LHS.getValueType() == RHS.getValueType() && + "Expected both operands to be the same type"); + + unsigned BitWidth = LHS.getValueSizeInBits(); + EVT ResultType = LHS.getValueType(); + EVT BoolVT = + getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), ResultType); + SDValue Result = + DAG.getNode(OverflowOp, dl, DAG.getVTList(ResultType, BoolVT), LHS, RHS); + SDValue SumDiff = Result.getValue(0); + SDValue Overflow = Result.getValue(1); + SDValue Zero = DAG.getConstant(0, dl, ResultType); + + if (Opcode == ISD::UADDSAT) { + // Just need to check overflow for SatMax. + APInt MaxVal = APInt::getMaxValue(BitWidth); + SDValue SatMax = DAG.getConstant(MaxVal, dl, ResultType); + return DAG.getSelect(dl, ResultType, Overflow, SatMax, SumDiff); + } else if (Opcode == ISD::USUBSAT) { + // Just need to check overflow for SatMin. + APInt MinVal = APInt::getMinValue(BitWidth); + SDValue SatMin = DAG.getConstant(MinVal, dl, ResultType); + return DAG.getSelect(dl, ResultType, Overflow, SatMin, SumDiff); + } else { + // SatMax -> Overflow && SumDiff < 0 + // SatMin -> Overflow && SumDiff >= 0 + APInt MinVal = APInt::getSignedMinValue(BitWidth); + APInt MaxVal = APInt::getSignedMaxValue(BitWidth); + SDValue SatMin = DAG.getConstant(MinVal, dl, ResultType); + SDValue SatMax = DAG.getConstant(MaxVal, dl, ResultType); + SDValue SumNeg = DAG.getSetCC(dl, BoolVT, SumDiff, Zero, ISD::SETLT); + Result = DAG.getSelect(dl, ResultType, SumNeg, SatMax, SatMin); + return DAG.getSelect(dl, ResultType, Overflow, Result, SumDiff); + } +} + +SDValue +TargetLowering::getExpandedFixedPointMultiplication(SDNode *Node, + SelectionDAG &DAG) const { + assert(Node->getOpcode() == ISD::SMULFIX && "Expected opcode to be SMULFIX."); + assert(Node->getNumOperands() == 3 && + "Expected signed fixed point multiplication to have 3 operands."); + + SDLoc dl(Node); + SDValue LHS = Node->getOperand(0); + SDValue RHS = Node->getOperand(1); + assert(LHS.getValueType().isScalarInteger() && + "Expected operands to be integers. Vector of int arguments should " + "already be unrolled."); + assert(RHS.getValueType().isScalarInteger() && + "Expected operands to be integers. Vector of int arguments should " + "already be unrolled."); + assert(LHS.getValueType() == RHS.getValueType() && + "Expected both operands to be the same type"); + + unsigned Scale = Node->getConstantOperandVal(2); + EVT VT = LHS.getValueType(); + assert(Scale < VT.getScalarSizeInBits() && + "Expected scale to be less than the number of bits."); + + if (!Scale) + return DAG.getNode(ISD::MUL, dl, VT, LHS, RHS); + + // Get the upper and lower bits of the result. + SDValue Lo, Hi; + if (isOperationLegalOrCustom(ISD::SMUL_LOHI, VT)) { + SDValue Result = + DAG.getNode(ISD::SMUL_LOHI, dl, DAG.getVTList(VT, VT), LHS, RHS); + Lo = Result.getValue(0); + Hi = Result.getValue(1); + } else if (isOperationLegalOrCustom(ISD::MULHS, VT)) { + Lo = DAG.getNode(ISD::MUL, dl, VT, LHS, RHS); + Hi = DAG.getNode(ISD::MULHS, dl, VT, LHS, RHS); + } else { + report_fatal_error("Unable to expand signed fixed point multiplication."); + } + + // The result will need to be shifted right by the scale since both operands + // are scaled. The result is given to us in 2 halves, so we only want part of + // both in the result. + EVT ShiftTy = getShiftAmountTy(VT, DAG.getDataLayout()); + Lo = DAG.getNode(ISD::SRL, dl, VT, Lo, DAG.getConstant(Scale, dl, ShiftTy)); + Hi = DAG.getNode( + ISD::SHL, dl, VT, Hi, + DAG.getConstant(VT.getScalarSizeInBits() - Scale, dl, ShiftTy)); + return DAG.getNode(ISD::OR, dl, VT, Lo, Hi); +} |
