diff options
Diffstat (limited to 'contrib/llvm-project/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp')
| -rw-r--r-- | contrib/llvm-project/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp | 353 | 
1 files changed, 297 insertions, 56 deletions
diff --git a/contrib/llvm-project/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp b/contrib/llvm-project/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp index 91a64d59e154..8b15bdb0aca3 100644 --- a/contrib/llvm-project/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp +++ b/contrib/llvm-project/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp @@ -5940,62 +5940,6 @@ bool CombinerHelper::matchCombineFSubFpExtFNegFMulToFMadOrFMA(    return false;  } -bool CombinerHelper::matchSelectToLogical(MachineInstr &MI, -                                          BuildFnTy &MatchInfo) { -  GSelect &Sel = cast<GSelect>(MI); -  Register DstReg = Sel.getReg(0); -  Register Cond = Sel.getCondReg(); -  Register TrueReg = Sel.getTrueReg(); -  Register FalseReg = Sel.getFalseReg(); - -  auto *TrueDef = getDefIgnoringCopies(TrueReg, MRI); -  auto *FalseDef = getDefIgnoringCopies(FalseReg, MRI); - -  const LLT CondTy = MRI.getType(Cond); -  const LLT OpTy = MRI.getType(TrueReg); -  if (CondTy != OpTy || OpTy.getScalarSizeInBits() != 1) -    return false; - -  // We have a boolean select. - -  // select Cond, Cond, F --> or Cond, F -  // select Cond, 1, F    --> or Cond, F -  auto MaybeCstTrue = isConstantOrConstantSplatVector(*TrueDef, MRI); -  if (Cond == TrueReg || (MaybeCstTrue && MaybeCstTrue->isOne())) { -    MatchInfo = [=](MachineIRBuilder &MIB) { -      MIB.buildOr(DstReg, Cond, FalseReg); -    }; -    return true; -  } - -  // select Cond, T, Cond --> and Cond, T -  // select Cond, T, 0    --> and Cond, T -  auto MaybeCstFalse = isConstantOrConstantSplatVector(*FalseDef, MRI); -  if (Cond == FalseReg || (MaybeCstFalse && MaybeCstFalse->isZero())) { -    MatchInfo = [=](MachineIRBuilder &MIB) { -      MIB.buildAnd(DstReg, Cond, TrueReg); -    }; -    return true; -  } - - // select Cond, T, 1 --> or (not Cond), T -  if (MaybeCstFalse && MaybeCstFalse->isOne()) { -    MatchInfo = [=](MachineIRBuilder &MIB) { -      MIB.buildOr(DstReg, MIB.buildNot(OpTy, Cond), TrueReg); -    }; -    return true; -  } - -  // select Cond, 0, F --> and (not Cond), F -  if (MaybeCstTrue && MaybeCstTrue->isZero()) { -    MatchInfo = [=](MachineIRBuilder &MIB) { -      MIB.buildAnd(DstReg, MIB.buildNot(OpTy, Cond), FalseReg); -    }; -    return true; -  } -  return false; -} -  bool CombinerHelper::matchCombineFMinMaxNaN(MachineInstr &MI,                                              unsigned &IdxToPropagate) {    bool PropagateNaN; @@ -6318,3 +6262,300 @@ void CombinerHelper::applyCommuteBinOpOperands(MachineInstr &MI) {    MI.getOperand(2).setReg(LHSReg);    Observer.changedInstr(MI);  } + +bool CombinerHelper::isOneOrOneSplat(Register Src, bool AllowUndefs) { +  LLT SrcTy = MRI.getType(Src); +  if (SrcTy.isFixedVector()) +    return isConstantSplatVector(Src, 1, AllowUndefs); +  if (SrcTy.isScalar()) { +    if (AllowUndefs && getOpcodeDef<GImplicitDef>(Src, MRI) != nullptr) +      return true; +    auto IConstant = getIConstantVRegValWithLookThrough(Src, MRI); +    return IConstant && IConstant->Value == 1; +  } +  return false; // scalable vector +} + +bool CombinerHelper::isZeroOrZeroSplat(Register Src, bool AllowUndefs) { +  LLT SrcTy = MRI.getType(Src); +  if (SrcTy.isFixedVector()) +    return isConstantSplatVector(Src, 0, AllowUndefs); +  if (SrcTy.isScalar()) { +    if (AllowUndefs && getOpcodeDef<GImplicitDef>(Src, MRI) != nullptr) +      return true; +    auto IConstant = getIConstantVRegValWithLookThrough(Src, MRI); +    return IConstant && IConstant->Value == 0; +  } +  return false; // scalable vector +} + +// Ignores COPYs during conformance checks. +// FIXME scalable vectors. +bool CombinerHelper::isConstantSplatVector(Register Src, int64_t SplatValue, +                                           bool AllowUndefs) { +  GBuildVector *BuildVector = getOpcodeDef<GBuildVector>(Src, MRI); +  if (!BuildVector) +    return false; +  unsigned NumSources = BuildVector->getNumSources(); + +  for (unsigned I = 0; I < NumSources; ++I) { +    GImplicitDef *ImplicitDef = +        getOpcodeDef<GImplicitDef>(BuildVector->getSourceReg(I), MRI); +    if (ImplicitDef && AllowUndefs) +      continue; +    if (ImplicitDef && !AllowUndefs) +      return false; +    std::optional<ValueAndVReg> IConstant = +        getIConstantVRegValWithLookThrough(BuildVector->getSourceReg(I), MRI); +    if (IConstant && IConstant->Value == SplatValue) +      continue; +    return false; +  } +  return true; +} + +// Ignores COPYs during lookups. +// FIXME scalable vectors +std::optional<APInt> +CombinerHelper::getConstantOrConstantSplatVector(Register Src) { +  auto IConstant = getIConstantVRegValWithLookThrough(Src, MRI); +  if (IConstant) +    return IConstant->Value; + +  GBuildVector *BuildVector = getOpcodeDef<GBuildVector>(Src, MRI); +  if (!BuildVector) +    return std::nullopt; +  unsigned NumSources = BuildVector->getNumSources(); + +  std::optional<APInt> Value = std::nullopt; +  for (unsigned I = 0; I < NumSources; ++I) { +    std::optional<ValueAndVReg> IConstant = +        getIConstantVRegValWithLookThrough(BuildVector->getSourceReg(I), MRI); +    if (!IConstant) +      return std::nullopt; +    if (!Value) +      Value = IConstant->Value; +    else if (*Value != IConstant->Value) +      return std::nullopt; +  } +  return Value; +} + +// TODO: use knownbits to determine zeros +bool CombinerHelper::tryFoldSelectOfConstants(GSelect *Select, +                                              BuildFnTy &MatchInfo) { +  uint32_t Flags = Select->getFlags(); +  Register Dest = Select->getReg(0); +  Register Cond = Select->getCondReg(); +  Register True = Select->getTrueReg(); +  Register False = Select->getFalseReg(); +  LLT CondTy = MRI.getType(Select->getCondReg()); +  LLT TrueTy = MRI.getType(Select->getTrueReg()); + +  // We only do this combine for scalar boolean conditions. +  if (CondTy != LLT::scalar(1)) +    return false; + +  // Both are scalars. +  std::optional<ValueAndVReg> TrueOpt = +      getIConstantVRegValWithLookThrough(True, MRI); +  std::optional<ValueAndVReg> FalseOpt = +      getIConstantVRegValWithLookThrough(False, MRI); + +  if (!TrueOpt || !FalseOpt) +    return false; + +  APInt TrueValue = TrueOpt->Value; +  APInt FalseValue = FalseOpt->Value; + +  // select Cond, 1, 0 --> zext (Cond) +  if (TrueValue.isOne() && FalseValue.isZero()) { +    MatchInfo = [=](MachineIRBuilder &B) { +      B.setInstrAndDebugLoc(*Select); +      B.buildZExtOrTrunc(Dest, Cond); +    }; +    return true; +  } + +  // select Cond, -1, 0 --> sext (Cond) +  if (TrueValue.isAllOnes() && FalseValue.isZero()) { +    MatchInfo = [=](MachineIRBuilder &B) { +      B.setInstrAndDebugLoc(*Select); +      B.buildSExtOrTrunc(Dest, Cond); +    }; +    return true; +  } + +  // select Cond, 0, 1 --> zext (!Cond) +  if (TrueValue.isZero() && FalseValue.isOne()) { +    MatchInfo = [=](MachineIRBuilder &B) { +      B.setInstrAndDebugLoc(*Select); +      Register Inner = MRI.createGenericVirtualRegister(CondTy); +      B.buildNot(Inner, Cond); +      B.buildZExtOrTrunc(Dest, Inner); +    }; +    return true; +  } + +  // select Cond, 0, -1 --> sext (!Cond) +  if (TrueValue.isZero() && FalseValue.isAllOnes()) { +    MatchInfo = [=](MachineIRBuilder &B) { +      B.setInstrAndDebugLoc(*Select); +      Register Inner = MRI.createGenericVirtualRegister(CondTy); +      B.buildNot(Inner, Cond); +      B.buildSExtOrTrunc(Dest, Inner); +    }; +    return true; +  } + +  // select Cond, C1, C1-1 --> add (zext Cond), C1-1 +  if (TrueValue - 1 == FalseValue) { +    MatchInfo = [=](MachineIRBuilder &B) { +      B.setInstrAndDebugLoc(*Select); +      Register Inner = MRI.createGenericVirtualRegister(TrueTy); +      B.buildZExtOrTrunc(Inner, Cond); +      B.buildAdd(Dest, Inner, False); +    }; +    return true; +  } + +  // select Cond, C1, C1+1 --> add (sext Cond), C1+1 +  if (TrueValue + 1 == FalseValue) { +    MatchInfo = [=](MachineIRBuilder &B) { +      B.setInstrAndDebugLoc(*Select); +      Register Inner = MRI.createGenericVirtualRegister(TrueTy); +      B.buildSExtOrTrunc(Inner, Cond); +      B.buildAdd(Dest, Inner, False); +    }; +    return true; +  } + +  // select Cond, Pow2, 0 --> (zext Cond) << log2(Pow2) +  if (TrueValue.isPowerOf2() && FalseValue.isZero()) { +    MatchInfo = [=](MachineIRBuilder &B) { +      B.setInstrAndDebugLoc(*Select); +      Register Inner = MRI.createGenericVirtualRegister(TrueTy); +      B.buildZExtOrTrunc(Inner, Cond); +      // The shift amount must be scalar. +      LLT ShiftTy = TrueTy.isVector() ? TrueTy.getElementType() : TrueTy; +      auto ShAmtC = B.buildConstant(ShiftTy, TrueValue.exactLogBase2()); +      B.buildShl(Dest, Inner, ShAmtC, Flags); +    }; +    return true; +  } +  // select Cond, -1, C --> or (sext Cond), C +  if (TrueValue.isAllOnes()) { +    MatchInfo = [=](MachineIRBuilder &B) { +      B.setInstrAndDebugLoc(*Select); +      Register Inner = MRI.createGenericVirtualRegister(TrueTy); +      B.buildSExtOrTrunc(Inner, Cond); +      B.buildOr(Dest, Inner, False, Flags); +    }; +    return true; +  } + +  // select Cond, C, -1 --> or (sext (not Cond)), C +  if (FalseValue.isAllOnes()) { +    MatchInfo = [=](MachineIRBuilder &B) { +      B.setInstrAndDebugLoc(*Select); +      Register Not = MRI.createGenericVirtualRegister(CondTy); +      B.buildNot(Not, Cond); +      Register Inner = MRI.createGenericVirtualRegister(TrueTy); +      B.buildSExtOrTrunc(Inner, Not); +      B.buildOr(Dest, Inner, True, Flags); +    }; +    return true; +  } + +  return false; +} + +// TODO: use knownbits to determine zeros +bool CombinerHelper::tryFoldBoolSelectToLogic(GSelect *Select, +                                              BuildFnTy &MatchInfo) { +  uint32_t Flags = Select->getFlags(); +  Register DstReg = Select->getReg(0); +  Register Cond = Select->getCondReg(); +  Register True = Select->getTrueReg(); +  Register False = Select->getFalseReg(); +  LLT CondTy = MRI.getType(Select->getCondReg()); +  LLT TrueTy = MRI.getType(Select->getTrueReg()); + +  // Boolean or fixed vector of booleans. +  if (CondTy.isScalableVector() || +      (CondTy.isFixedVector() && +       CondTy.getElementType().getScalarSizeInBits() != 1) || +      CondTy.getScalarSizeInBits() != 1) +    return false; + +  if (CondTy != TrueTy) +    return false; + +  // select Cond, Cond, F --> or Cond, F +  // select Cond, 1, F    --> or Cond, F +  if ((Cond == True) || isOneOrOneSplat(True, /* AllowUndefs */ true)) { +    MatchInfo = [=](MachineIRBuilder &B) { +      B.setInstrAndDebugLoc(*Select); +      Register Ext = MRI.createGenericVirtualRegister(TrueTy); +      B.buildZExtOrTrunc(Ext, Cond); +      B.buildOr(DstReg, Ext, False, Flags); +    }; +    return true; +  } + +  // select Cond, T, Cond --> and Cond, T +  // select Cond, T, 0    --> and Cond, T +  if ((Cond == False) || isZeroOrZeroSplat(False, /* AllowUndefs */ true)) { +    MatchInfo = [=](MachineIRBuilder &B) { +      B.setInstrAndDebugLoc(*Select); +      Register Ext = MRI.createGenericVirtualRegister(TrueTy); +      B.buildZExtOrTrunc(Ext, Cond); +      B.buildAnd(DstReg, Ext, True); +    }; +    return true; +  } + +  // select Cond, T, 1 --> or (not Cond), T +  if (isOneOrOneSplat(False, /* AllowUndefs */ true)) { +    MatchInfo = [=](MachineIRBuilder &B) { +      B.setInstrAndDebugLoc(*Select); +      // First the not. +      Register Inner = MRI.createGenericVirtualRegister(CondTy); +      B.buildNot(Inner, Cond); +      // Then an ext to match the destination register. +      Register Ext = MRI.createGenericVirtualRegister(TrueTy); +      B.buildZExtOrTrunc(Ext, Inner); +      B.buildOr(DstReg, Ext, True, Flags); +    }; +    return true; +  } + +  // select Cond, 0, F --> and (not Cond), F +  if (isZeroOrZeroSplat(True, /* AllowUndefs */ true)) { +    MatchInfo = [=](MachineIRBuilder &B) { +      B.setInstrAndDebugLoc(*Select); +      // First the not. +      Register Inner = MRI.createGenericVirtualRegister(CondTy); +      B.buildNot(Inner, Cond); +      // Then an ext to match the destination register. +      Register Ext = MRI.createGenericVirtualRegister(TrueTy); +      B.buildZExtOrTrunc(Ext, Inner); +      B.buildAnd(DstReg, Ext, False); +    }; +    return true; +  } + +  return false; +} + +bool CombinerHelper::matchSelect(MachineInstr &MI, BuildFnTy &MatchInfo) { +  GSelect *Select = cast<GSelect>(&MI); + +  if (tryFoldSelectOfConstants(Select, MatchInfo)) +    return true; + +  if (tryFoldBoolSelectToLogic(Select, MatchInfo)) +    return true; + +  return false; +}  | 
