aboutsummaryrefslogtreecommitdiff
path: root/contrib/llvm-project/llvm/lib/Target/X86/X86ISelLowering.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'contrib/llvm-project/llvm/lib/Target/X86/X86ISelLowering.cpp')
-rw-r--r--contrib/llvm-project/llvm/lib/Target/X86/X86ISelLowering.cpp122
1 files changed, 95 insertions, 27 deletions
diff --git a/contrib/llvm-project/llvm/lib/Target/X86/X86ISelLowering.cpp b/contrib/llvm-project/llvm/lib/Target/X86/X86ISelLowering.cpp
index 1e4b1361f98a..5a28240ea9e2 100644
--- a/contrib/llvm-project/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/contrib/llvm-project/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -7371,7 +7371,7 @@ static SDValue lowerBuildVectorAsBroadcast(BuildVectorSDNode *BVOp,
/// index.
static int getUnderlyingExtractedFromVec(SDValue &ExtractedFromVec,
SDValue ExtIdx) {
- int Idx = cast<ConstantSDNode>(ExtIdx)->getZExtValue();
+ int Idx = ExtIdx->getAsZExtVal();
if (!isa<ShuffleVectorSDNode>(ExtractedFromVec))
return Idx;
@@ -7475,10 +7475,12 @@ static SDValue buildFromShuffleMostly(SDValue Op, SelectionDAG &DAG) {
static SDValue LowerBUILD_VECTORvXbf16(SDValue Op, SelectionDAG &DAG,
const X86Subtarget &Subtarget) {
MVT VT = Op.getSimpleValueType();
- MVT IVT = VT.changeVectorElementTypeToInteger();
+ MVT IVT =
+ VT.changeVectorElementType(Subtarget.hasFP16() ? MVT::f16 : MVT::i16);
SmallVector<SDValue, 16> NewOps;
for (unsigned I = 0, E = Op.getNumOperands(); I != E; ++I)
- NewOps.push_back(DAG.getBitcast(MVT::i16, Op.getOperand(I)));
+ NewOps.push_back(DAG.getBitcast(Subtarget.hasFP16() ? MVT::f16 : MVT::i16,
+ Op.getOperand(I)));
SDValue Res = DAG.getNode(ISD::BUILD_VECTOR, SDLoc(), IVT, NewOps);
return DAG.getBitcast(VT, Res);
}
@@ -8793,7 +8795,7 @@ X86TargetLowering::LowerBUILD_VECTOR(SDValue Op, SelectionDAG &DAG) const {
MachineFunction &MF = DAG.getMachineFunction();
MachinePointerInfo MPI = MachinePointerInfo::getConstantPool(MF);
SDValue Ld = DAG.getLoad(VT, dl, DAG.getEntryNode(), LegalDAGConstVec, MPI);
- unsigned InsertC = cast<ConstantSDNode>(InsIndex)->getZExtValue();
+ unsigned InsertC = InsIndex->getAsZExtVal();
unsigned NumEltsInLow128Bits = 128 / VT.getScalarSizeInBits();
if (InsertC < NumEltsInLow128Bits)
return DAG.getNode(ISD::INSERT_VECTOR_ELT, dl, VT, Ld, VarElt, InsIndex);
@@ -14369,6 +14371,13 @@ static SDValue lower128BitShuffle(const SDLoc &DL, ArrayRef<int> Mask,
const APInt &Zeroable,
const X86Subtarget &Subtarget,
SelectionDAG &DAG) {
+ if (VT == MVT::v8bf16) {
+ V1 = DAG.getBitcast(MVT::v8i16, V1);
+ V2 = DAG.getBitcast(MVT::v8i16, V2);
+ return DAG.getBitcast(VT,
+ DAG.getVectorShuffle(MVT::v8i16, DL, V1, V2, Mask));
+ }
+
switch (VT.SimpleTy) {
case MVT::v2i64:
return lowerV2I64Shuffle(DL, Mask, Zeroable, V1, V2, Subtarget, DAG);
@@ -17096,14 +17105,14 @@ static SDValue lower512BitShuffle(const SDLoc &DL, ArrayRef<int> Mask,
return splitAndLowerShuffle(DL, VT, V1, V2, Mask, DAG, /*SimpleOnly*/ false);
}
- if (VT == MVT::v32f16) {
+ if (VT == MVT::v32f16 || VT == MVT::v32bf16) {
if (!Subtarget.hasBWI())
return splitAndLowerShuffle(DL, VT, V1, V2, Mask, DAG,
/*SimpleOnly*/ false);
V1 = DAG.getBitcast(MVT::v32i16, V1);
V2 = DAG.getBitcast(MVT::v32i16, V2);
- return DAG.getBitcast(MVT::v32f16,
+ return DAG.getBitcast(VT,
DAG.getVectorShuffle(MVT::v32i16, DL, V1, V2, Mask));
}
@@ -17747,7 +17756,7 @@ static SDValue LowerEXTRACT_VECTOR_ELT_SSE4(SDValue Op, SelectionDAG &DAG) {
DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, MVT::i32,
DAG.getBitcast(MVT::v4i32, Vec), Idx));
- unsigned IdxVal = cast<ConstantSDNode>(Idx)->getZExtValue();
+ unsigned IdxVal = Idx->getAsZExtVal();
SDValue Extract = DAG.getNode(X86ISD::PEXTRB, dl, MVT::i32, Vec,
DAG.getTargetConstant(IdxVal, dl, MVT::i8));
return DAG.getNode(ISD::TRUNCATE, dl, VT, Extract);
@@ -21515,9 +21524,8 @@ SDValue X86TargetLowering::LowerFP_TO_BF16(SDValue Op,
RTLIB::Libcall LC =
RTLIB::getFPROUND(Op.getOperand(0).getValueType(), MVT::bf16);
SDValue Res =
- makeLibCall(DAG, LC, MVT::f32, Op.getOperand(0), CallOptions, DL).first;
- return DAG.getNode(ISD::TRUNCATE, DL, MVT::i16,
- DAG.getBitcast(MVT::i32, Res));
+ makeLibCall(DAG, LC, MVT::f16, Op.getOperand(0), CallOptions, DL).first;
+ return DAG.getBitcast(MVT::i16, Res);
}
/// Depending on uarch and/or optimizing for size, we might prefer to use a
@@ -24061,7 +24069,7 @@ SDValue X86TargetLowering::LowerSELECT(SDValue Op, SelectionDAG &DAG) const {
// a >= b ? -1 : 0 -> RES = setcc_carry
// a >= b ? 0 : -1 -> RES = ~setcc_carry
if (Cond.getOpcode() == X86ISD::SUB) {
- unsigned CondCode = cast<ConstantSDNode>(CC)->getZExtValue();
+ unsigned CondCode = CC->getAsZExtVal();
if ((CondCode == X86::COND_AE || CondCode == X86::COND_B) &&
(isAllOnesConstant(Op1) || isAllOnesConstant(Op2)) &&
@@ -25359,8 +25367,7 @@ SDValue X86TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op,
if (IntrData->Type == INTR_TYPE_3OP_IMM8 &&
Src3.getValueType() != MVT::i8) {
- Src3 = DAG.getTargetConstant(
- cast<ConstantSDNode>(Src3)->getZExtValue() & 0xff, dl, MVT::i8);
+ Src3 = DAG.getTargetConstant(Src3->getAsZExtVal() & 0xff, dl, MVT::i8);
}
// We specify 2 possible opcodes for intrinsics with rounding modes.
@@ -25385,8 +25392,7 @@ SDValue X86TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op,
assert(Op.getOperand(4)->getOpcode() == ISD::TargetConstant);
SDValue Src4 = Op.getOperand(4);
if (Src4.getValueType() != MVT::i8) {
- Src4 = DAG.getTargetConstant(
- cast<ConstantSDNode>(Src4)->getZExtValue() & 0xff, dl, MVT::i8);
+ Src4 = DAG.getTargetConstant(Src4->getAsZExtVal() & 0xff, dl, MVT::i8);
}
return DAG.getNode(IntrData->Opc0, dl, Op.getValueType(),
@@ -26788,7 +26794,7 @@ static SDValue LowerINTRINSIC_W_CHAIN(SDValue Op, const X86Subtarget &Subtarget,
{Chain, Op1, Op2, Size}, VT, MMO);
Chain = Res.getValue(1);
Res = DAG.getZExtOrTrunc(getSETCC(X86::COND_B, Res, DL, DAG), DL, VT);
- unsigned Imm = cast<ConstantSDNode>(Op2)->getZExtValue();
+ unsigned Imm = Op2->getAsZExtVal();
if (Imm)
Res = DAG.getNode(ISD::SHL, DL, VT, Res,
DAG.getShiftAmountConstant(Imm, VT, DL));
@@ -40221,6 +40227,34 @@ static SDValue combineTargetShuffle(SDValue N, SelectionDAG &DAG,
}
return SDValue();
}
+ case X86ISD::SHUF128: {
+ // If we're permuting the upper 256-bits subvectors of a concatenation, then
+ // see if we can peek through and access the subvector directly.
+ if (VT.is512BitVector()) {
+ // 512-bit mask uses 4 x i2 indices - if the msb is always set then only the
+ // upper subvector is used.
+ SDValue LHS = N->getOperand(0);
+ SDValue RHS = N->getOperand(1);
+ uint64_t Mask = N->getConstantOperandVal(2);
+ SmallVector<SDValue> LHSOps, RHSOps;
+ SDValue NewLHS, NewRHS;
+ if ((Mask & 0x0A) == 0x0A &&
+ collectConcatOps(LHS.getNode(), LHSOps, DAG) && LHSOps.size() == 2) {
+ NewLHS = widenSubVector(LHSOps[1], false, Subtarget, DAG, DL, 512);
+ Mask &= ~0x0A;
+ }
+ if ((Mask & 0xA0) == 0xA0 &&
+ collectConcatOps(RHS.getNode(), RHSOps, DAG) && RHSOps.size() == 2) {
+ NewRHS = widenSubVector(RHSOps[1], false, Subtarget, DAG, DL, 512);
+ Mask &= ~0xA0;
+ }
+ if (NewLHS || NewRHS)
+ return DAG.getNode(X86ISD::SHUF128, DL, VT, NewLHS ? NewLHS : LHS,
+ NewRHS ? NewRHS : RHS,
+ DAG.getTargetConstant(Mask, DL, MVT::i8));
+ }
+ return SDValue();
+ }
case X86ISD::VPERM2X128: {
// Fold vperm2x128(bitcast(x),bitcast(y),c) -> bitcast(vperm2x128(x,y,c)).
SDValue LHS = N->getOperand(0);
@@ -41320,6 +41354,20 @@ bool X86TargetLowering::SimplifyDemandedVectorEltsForTargetNode(
return TLO.CombineTo(Op, Src);
break;
}
+ case X86ISD::VZEXT_LOAD: {
+ // If upper demanded elements are not demanded then simplify to a
+ // scalar_to_vector(load()).
+ MVT SVT = VT.getSimpleVT().getVectorElementType();
+ if (DemandedElts == 1 && Op.getValue(1).use_empty() && isTypeLegal(SVT)) {
+ SDLoc DL(Op);
+ auto *Mem = cast<MemSDNode>(Op);
+ SDValue Elt = TLO.DAG.getLoad(SVT, DL, Mem->getChain(), Mem->getBasePtr(),
+ Mem->getMemOperand());
+ SDValue Vec = TLO.DAG.getNode(ISD::SCALAR_TO_VECTOR, DL, VT, Elt);
+ return TLO.CombineTo(Op, TLO.DAG.getBitcast(VT, Vec));
+ }
+ break;
+ }
case X86ISD::VBROADCAST: {
SDValue Src = Op.getOperand(0);
MVT SrcVT = Src.getSimpleValueType();
@@ -41795,7 +41843,7 @@ bool X86TargetLowering::SimplifyDemandedBitsForTargetNode(
SDValue Op0 = Op.getOperand(0);
SDValue Op1 = Op.getOperand(1);
- unsigned ShAmt = cast<ConstantSDNode>(Op1)->getZExtValue();
+ unsigned ShAmt = Op1->getAsZExtVal();
if (ShAmt >= BitWidth)
break;
@@ -42580,7 +42628,7 @@ static SDValue combinevXi1ConstantToInteger(SDValue Op, SelectionDAG &DAG) {
APInt Imm(SrcVT.getVectorNumElements(), 0);
for (unsigned Idx = 0, e = Op.getNumOperands(); Idx < e; ++Idx) {
SDValue In = Op.getOperand(Idx);
- if (!In.isUndef() && (cast<ConstantSDNode>(In)->getZExtValue() & 0x1))
+ if (!In.isUndef() && (In->getAsZExtVal() & 0x1))
Imm.setBit(Idx);
}
EVT IntVT = EVT::getIntegerVT(*DAG.getContext(), Imm.getBitWidth());
@@ -49931,18 +49979,17 @@ static SDValue combineLoad(SDNode *N, SelectionDAG &DAG,
SDValue Ptr = Ld->getBasePtr();
SDValue Chain = Ld->getChain();
for (SDNode *User : Chain->uses()) {
- if (User != N &&
+ auto *UserLd = dyn_cast<MemSDNode>(User);
+ if (User != N && UserLd &&
(User->getOpcode() == X86ISD::SUBV_BROADCAST_LOAD ||
User->getOpcode() == X86ISD::VBROADCAST_LOAD ||
ISD::isNormalLoad(User)) &&
- cast<MemSDNode>(User)->getChain() == Chain &&
- !User->hasAnyUseOfValue(1) &&
+ UserLd->getChain() == Chain && !User->hasAnyUseOfValue(1) &&
User->getValueSizeInBits(0).getFixedValue() >
RegVT.getFixedSizeInBits()) {
if (User->getOpcode() == X86ISD::SUBV_BROADCAST_LOAD &&
- cast<MemSDNode>(User)->getBasePtr() == Ptr &&
- cast<MemSDNode>(User)->getMemoryVT().getSizeInBits() ==
- MemVT.getSizeInBits()) {
+ UserLd->getBasePtr() == Ptr &&
+ UserLd->getMemoryVT().getSizeInBits() == MemVT.getSizeInBits()) {
SDValue Extract = extractSubVector(SDValue(User, 0), 0, DAG, SDLoc(N),
RegVT.getSizeInBits());
Extract = DAG.getBitcast(RegVT, Extract);
@@ -49961,7 +50008,7 @@ static SDValue combineLoad(SDNode *N, SelectionDAG &DAG,
// See if we are loading a constant that matches in the lower
// bits of a longer constant (but from a different constant pool ptr).
EVT UserVT = User->getValueType(0);
- SDValue UserPtr = cast<MemSDNode>(User)->getBasePtr();
+ SDValue UserPtr = UserLd->getBasePtr();
const Constant *LdC = getTargetConstantFromBasePtr(Ptr);
const Constant *UserC = getTargetConstantFromBasePtr(UserPtr);
if (LdC && UserC && UserPtr != Ptr) {
@@ -53258,7 +53305,7 @@ static SDValue combineGatherScatter(SDNode *N, SelectionDAG &DAG,
if (Index.getOpcode() == ISD::ADD &&
Index.getValueType().getVectorElementType() == PtrVT &&
isa<ConstantSDNode>(Scale)) {
- uint64_t ScaleAmt = cast<ConstantSDNode>(Scale)->getZExtValue();
+ uint64_t ScaleAmt = Scale->getAsZExtVal();
if (auto *BV = dyn_cast<BuildVectorSDNode>(Index.getOperand(1))) {
BitVector UndefElts;
if (ConstantSDNode *C = BV->getConstantSplatNode(&UndefElts)) {
@@ -54572,6 +54619,14 @@ static SDValue combineConcatVectorOps(const SDLoc &DL, MVT VT,
Op0.getValueType() == cast<MemSDNode>(SrcVec)->getMemoryVT())
return Op0.getOperand(0);
}
+
+ // concat_vectors(permq(x),permq(x)) -> permq(concat_vectors(x,x))
+ if (Op0.getOpcode() == X86ISD::VPERMI && Subtarget.useAVX512Regs() &&
+ !X86::mayFoldLoad(Op0.getOperand(0), Subtarget))
+ return DAG.getNode(Op0.getOpcode(), DL, VT,
+ DAG.getNode(ISD::CONCAT_VECTORS, DL, VT,
+ Op0.getOperand(0), Op0.getOperand(0)),
+ Op0.getOperand(1));
}
// concat(extract_subvector(v0,c0), extract_subvector(v1,c1)) -> vperm2x128.
@@ -54979,6 +55034,19 @@ static SDValue combineConcatVectorOps(const SDLoc &DL, MVT VT,
ConcatSubOperand(VT, Ops, 1), Op0.getOperand(2));
}
break;
+ case X86ISD::BLENDI:
+ if (NumOps == 2 && VT.is512BitVector() && Subtarget.useBWIRegs()) {
+ uint64_t Mask0 = Ops[0].getConstantOperandVal(2);
+ uint64_t Mask1 = Ops[1].getConstantOperandVal(2);
+ uint64_t Mask = (Mask1 << (VT.getVectorNumElements() / 2)) | Mask0;
+ MVT MaskSVT = MVT::getIntegerVT(VT.getVectorNumElements());
+ MVT MaskVT = MVT::getVectorVT(MVT::i1, VT.getVectorNumElements());
+ SDValue Sel =
+ DAG.getBitcast(MaskVT, DAG.getConstant(Mask, DL, MaskSVT));
+ return DAG.getSelect(DL, VT, Sel, ConcatSubOperand(VT, Ops, 1),
+ ConcatSubOperand(VT, Ops, 0));
+ }
+ break;
case ISD::VSELECT:
if (!IsSplat && Subtarget.hasAVX512() &&
(VT.is256BitVector() ||
@@ -57602,7 +57670,7 @@ X86TargetLowering::getStackProbeSize(const MachineFunction &MF) const {
}
Align X86TargetLowering::getPrefLoopAlignment(MachineLoop *ML) const {
- if (ML->isInnermost() &&
+ if (ML && ML->isInnermost() &&
ExperimentalPrefInnermostLoopAlignment.getNumOccurrences())
return Align(1ULL << ExperimentalPrefInnermostLoopAlignment);
return TargetLowering::getPrefLoopAlignment();