diff options
Diffstat (limited to 'contrib/llvm-project/llvm/lib/Transforms/InstCombine')
16 files changed, 39981 insertions, 0 deletions
diff --git a/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp b/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp new file mode 100644 index 000000000000..f4d8b79a5311 --- /dev/null +++ b/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp @@ -0,0 +1,2493 @@ +//===- InstCombineAddSub.cpp ------------------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the visit functions for add, fadd, sub, and fsub. +// +//===----------------------------------------------------------------------===// + +#include "InstCombineInternal.h" +#include "llvm/ADT/APFloat.h" +#include "llvm/ADT/APInt.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Analysis/InstructionSimplify.h" +#include "llvm/Analysis/ValueTracking.h" +#include "llvm/IR/Constant.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/InstrTypes.h" +#include "llvm/IR/Instruction.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/Operator.h" +#include "llvm/IR/PatternMatch.h" +#include "llvm/IR/Type.h" +#include "llvm/IR/Value.h" +#include "llvm/Support/AlignOf.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/KnownBits.h" +#include "llvm/Transforms/InstCombine/InstCombiner.h" +#include <cassert> +#include <utility> + +using namespace llvm; +using namespace PatternMatch; + +#define DEBUG_TYPE "instcombine" + +namespace { + + /// Class representing coefficient of floating-point addend. + /// This class needs to be highly efficient, which is especially true for + /// the constructor. 
As of I write this comment, the cost of the default + /// constructor is merely 4-byte-store-zero (Assuming compiler is able to + /// perform write-merging). + /// + class FAddendCoef { + public: + // The constructor has to initialize a APFloat, which is unnecessary for + // most addends which have coefficient either 1 or -1. So, the constructor + // is expensive. In order to avoid the cost of the constructor, we should + // reuse some instances whenever possible. The pre-created instances + // FAddCombine::Add[0-5] embodies this idea. + FAddendCoef() = default; + ~FAddendCoef(); + + // If possible, don't define operator+/operator- etc because these + // operators inevitably call FAddendCoef's constructor which is not cheap. + void operator=(const FAddendCoef &A); + void operator+=(const FAddendCoef &A); + void operator*=(const FAddendCoef &S); + + void set(short C) { + assert(!insaneIntVal(C) && "Insane coefficient"); + IsFp = false; IntVal = C; + } + + void set(const APFloat& C); + + void negate(); + + bool isZero() const { return isInt() ? !IntVal : getFpVal().isZero(); } + Value *getValue(Type *) const; + + bool isOne() const { return isInt() && IntVal == 1; } + bool isTwo() const { return isInt() && IntVal == 2; } + bool isMinusOne() const { return isInt() && IntVal == -1; } + bool isMinusTwo() const { return isInt() && IntVal == -2; } + + private: + bool insaneIntVal(int V) { return V > 4 || V < -4; } + + APFloat *getFpValPtr() { return reinterpret_cast<APFloat *>(&FpValBuf); } + + const APFloat *getFpValPtr() const { + return reinterpret_cast<const APFloat *>(&FpValBuf); + } + + const APFloat &getFpVal() const { + assert(IsFp && BufHasFpVal && "Incorret state"); + return *getFpValPtr(); + } + + APFloat &getFpVal() { + assert(IsFp && BufHasFpVal && "Incorret state"); + return *getFpValPtr(); + } + + bool isInt() const { return !IsFp; } + + // If the coefficient is represented by an integer, promote it to a + // floating point. 
+ void convertToFpType(const fltSemantics &Sem); + + // Construct an APFloat from a signed integer. + // TODO: We should get rid of this function when APFloat can be constructed + // from an *SIGNED* integer. + APFloat createAPFloatFromInt(const fltSemantics &Sem, int Val); + + bool IsFp = false; + + // True iff FpValBuf contains an instance of APFloat. + bool BufHasFpVal = false; + + // The integer coefficient of an individual addend is either 1 or -1, + // and we try to simplify at most 4 addends from neighboring at most + // two instructions. So the range of <IntVal> falls in [-4, 4]. APInt + // is overkill of this end. + short IntVal = 0; + + AlignedCharArrayUnion<APFloat> FpValBuf; + }; + + /// FAddend is used to represent floating-point addend. An addend is + /// represented as <C, V>, where the V is a symbolic value, and C is a + /// constant coefficient. A constant addend is represented as <C, 0>. + class FAddend { + public: + FAddend() = default; + + void operator+=(const FAddend &T) { + assert((Val == T.Val) && "Symbolic-values disagree"); + Coeff += T.Coeff; + } + + Value *getSymVal() const { return Val; } + const FAddendCoef &getCoef() const { return Coeff; } + + bool isConstant() const { return Val == nullptr; } + bool isZero() const { return Coeff.isZero(); } + + void set(short Coefficient, Value *V) { + Coeff.set(Coefficient); + Val = V; + } + void set(const APFloat &Coefficient, Value *V) { + Coeff.set(Coefficient); + Val = V; + } + void set(const ConstantFP *Coefficient, Value *V) { + Coeff.set(Coefficient->getValueAPF()); + Val = V; + } + + void negate() { Coeff.negate(); } + + /// Drill down the U-D chain one step to find the definition of V, and + /// try to break the definition into one or two addends. + static unsigned drillValueDownOneStep(Value* V, FAddend &A0, FAddend &A1); + + /// Similar to FAddend::drillDownOneStep() except that the value being + /// splitted is the addend itself. 
+ unsigned drillAddendDownOneStep(FAddend &Addend0, FAddend &Addend1) const; + + private: + void Scale(const FAddendCoef& ScaleAmt) { Coeff *= ScaleAmt; } + + // This addend has the value of "Coeff * Val". + Value *Val = nullptr; + FAddendCoef Coeff; + }; + + /// FAddCombine is the class for optimizing an unsafe fadd/fsub along + /// with its neighboring at most two instructions. + /// + class FAddCombine { + public: + FAddCombine(InstCombiner::BuilderTy &B) : Builder(B) {} + + Value *simplify(Instruction *FAdd); + + private: + using AddendVect = SmallVector<const FAddend *, 4>; + + Value *simplifyFAdd(AddendVect& V, unsigned InstrQuota); + + /// Convert given addend to a Value + Value *createAddendVal(const FAddend &A, bool& NeedNeg); + + /// Return the number of instructions needed to emit the N-ary addition. + unsigned calcInstrNumber(const AddendVect& Vect); + + Value *createFSub(Value *Opnd0, Value *Opnd1); + Value *createFAdd(Value *Opnd0, Value *Opnd1); + Value *createFMul(Value *Opnd0, Value *Opnd1); + Value *createFNeg(Value *V); + Value *createNaryFAdd(const AddendVect& Opnds, unsigned InstrQuota); + void createInstPostProc(Instruction *NewInst, bool NoNumber = false); + + // Debugging stuff are clustered here. + #ifndef NDEBUG + unsigned CreateInstrNum; + void initCreateInstNum() { CreateInstrNum = 0; } + void incCreateInstNum() { CreateInstrNum++; } + #else + void initCreateInstNum() {} + void incCreateInstNum() {} + #endif + + InstCombiner::BuilderTy &Builder; + Instruction *Instr = nullptr; + }; + +} // end anonymous namespace + +//===----------------------------------------------------------------------===// +// +// Implementation of +// {FAddendCoef, FAddend, FAddition, FAddCombine}. 
+// +//===----------------------------------------------------------------------===// +FAddendCoef::~FAddendCoef() { + if (BufHasFpVal) + getFpValPtr()->~APFloat(); +} + +void FAddendCoef::set(const APFloat& C) { + APFloat *P = getFpValPtr(); + + if (isInt()) { + // As the buffer is meanless byte stream, we cannot call + // APFloat::operator=(). + new(P) APFloat(C); + } else + *P = C; + + IsFp = BufHasFpVal = true; +} + +void FAddendCoef::convertToFpType(const fltSemantics &Sem) { + if (!isInt()) + return; + + APFloat *P = getFpValPtr(); + if (IntVal > 0) + new(P) APFloat(Sem, IntVal); + else { + new(P) APFloat(Sem, 0 - IntVal); + P->changeSign(); + } + IsFp = BufHasFpVal = true; +} + +APFloat FAddendCoef::createAPFloatFromInt(const fltSemantics &Sem, int Val) { + if (Val >= 0) + return APFloat(Sem, Val); + + APFloat T(Sem, 0 - Val); + T.changeSign(); + + return T; +} + +void FAddendCoef::operator=(const FAddendCoef &That) { + if (That.isInt()) + set(That.IntVal); + else + set(That.getFpVal()); +} + +void FAddendCoef::operator+=(const FAddendCoef &That) { + RoundingMode RndMode = RoundingMode::NearestTiesToEven; + if (isInt() == That.isInt()) { + if (isInt()) + IntVal += That.IntVal; + else + getFpVal().add(That.getFpVal(), RndMode); + return; + } + + if (isInt()) { + const APFloat &T = That.getFpVal(); + convertToFpType(T.getSemantics()); + getFpVal().add(T, RndMode); + return; + } + + APFloat &T = getFpVal(); + T.add(createAPFloatFromInt(T.getSemantics(), That.IntVal), RndMode); +} + +void FAddendCoef::operator*=(const FAddendCoef &That) { + if (That.isOne()) + return; + + if (That.isMinusOne()) { + negate(); + return; + } + + if (isInt() && That.isInt()) { + int Res = IntVal * (int)That.IntVal; + assert(!insaneIntVal(Res) && "Insane int value"); + IntVal = Res; + return; + } + + const fltSemantics &Semantic = + isInt() ? 
That.getFpVal().getSemantics() : getFpVal().getSemantics(); + + if (isInt()) + convertToFpType(Semantic); + APFloat &F0 = getFpVal(); + + if (That.isInt()) + F0.multiply(createAPFloatFromInt(Semantic, That.IntVal), + APFloat::rmNearestTiesToEven); + else + F0.multiply(That.getFpVal(), APFloat::rmNearestTiesToEven); +} + +void FAddendCoef::negate() { + if (isInt()) + IntVal = 0 - IntVal; + else + getFpVal().changeSign(); +} + +Value *FAddendCoef::getValue(Type *Ty) const { + return isInt() ? + ConstantFP::get(Ty, float(IntVal)) : + ConstantFP::get(Ty->getContext(), getFpVal()); +} + +// The definition of <Val> Addends +// ========================================= +// A + B <1, A>, <1,B> +// A - B <1, A>, <1,B> +// 0 - B <-1, B> +// C * A, <C, A> +// A + C <1, A> <C, NULL> +// 0 +/- 0 <0, NULL> (corner case) +// +// Legend: A and B are not constant, C is constant +unsigned FAddend::drillValueDownOneStep + (Value *Val, FAddend &Addend0, FAddend &Addend1) { + Instruction *I = nullptr; + if (!Val || !(I = dyn_cast<Instruction>(Val))) + return 0; + + unsigned Opcode = I->getOpcode(); + + if (Opcode == Instruction::FAdd || Opcode == Instruction::FSub) { + ConstantFP *C0, *C1; + Value *Opnd0 = I->getOperand(0); + Value *Opnd1 = I->getOperand(1); + if ((C0 = dyn_cast<ConstantFP>(Opnd0)) && C0->isZero()) + Opnd0 = nullptr; + + if ((C1 = dyn_cast<ConstantFP>(Opnd1)) && C1->isZero()) + Opnd1 = nullptr; + + if (Opnd0) { + if (!C0) + Addend0.set(1, Opnd0); + else + Addend0.set(C0, nullptr); + } + + if (Opnd1) { + FAddend &Addend = Opnd0 ? Addend1 : Addend0; + if (!C1) + Addend.set(1, Opnd1); + else + Addend.set(C1, nullptr); + if (Opcode == Instruction::FSub) + Addend.negate(); + } + + if (Opnd0 || Opnd1) + return Opnd0 && Opnd1 ? 2 : 1; + + // Both operands are zero. Weird! 
+ Addend0.set(APFloat(C0->getValueAPF().getSemantics()), nullptr); + return 1; + } + + if (I->getOpcode() == Instruction::FMul) { + Value *V0 = I->getOperand(0); + Value *V1 = I->getOperand(1); + if (ConstantFP *C = dyn_cast<ConstantFP>(V0)) { + Addend0.set(C, V1); + return 1; + } + + if (ConstantFP *C = dyn_cast<ConstantFP>(V1)) { + Addend0.set(C, V0); + return 1; + } + } + + return 0; +} + +// Try to break *this* addend into two addends. e.g. Suppose this addend is +// <2.3, V>, and V = X + Y, by calling this function, we obtain two addends, +// i.e. <2.3, X> and <2.3, Y>. +unsigned FAddend::drillAddendDownOneStep + (FAddend &Addend0, FAddend &Addend1) const { + if (isConstant()) + return 0; + + unsigned BreakNum = FAddend::drillValueDownOneStep(Val, Addend0, Addend1); + if (!BreakNum || Coeff.isOne()) + return BreakNum; + + Addend0.Scale(Coeff); + + if (BreakNum == 2) + Addend1.Scale(Coeff); + + return BreakNum; +} + +Value *FAddCombine::simplify(Instruction *I) { + assert(I->hasAllowReassoc() && I->hasNoSignedZeros() && + "Expected 'reassoc'+'nsz' instruction"); + + // Currently we are not able to handle vector type. + if (I->getType()->isVectorTy()) + return nullptr; + + assert((I->getOpcode() == Instruction::FAdd || + I->getOpcode() == Instruction::FSub) && "Expect add/sub"); + + // Save the instruction before calling other member-functions. + Instr = I; + + FAddend Opnd0, Opnd1, Opnd0_0, Opnd0_1, Opnd1_0, Opnd1_1; + + unsigned OpndNum = FAddend::drillValueDownOneStep(I, Opnd0, Opnd1); + + // Step 1: Expand the 1st addend into Opnd0_0 and Opnd0_1. + unsigned Opnd0_ExpNum = 0; + unsigned Opnd1_ExpNum = 0; + + if (!Opnd0.isConstant()) + Opnd0_ExpNum = Opnd0.drillAddendDownOneStep(Opnd0_0, Opnd0_1); + + // Step 2: Expand the 2nd addend into Opnd1_0 and Opnd1_1. 
+ if (OpndNum == 2 && !Opnd1.isConstant()) + Opnd1_ExpNum = Opnd1.drillAddendDownOneStep(Opnd1_0, Opnd1_1); + + // Step 3: Try to optimize Opnd0_0 + Opnd0_1 + Opnd1_0 + Opnd1_1 + if (Opnd0_ExpNum && Opnd1_ExpNum) { + AddendVect AllOpnds; + AllOpnds.push_back(&Opnd0_0); + AllOpnds.push_back(&Opnd1_0); + if (Opnd0_ExpNum == 2) + AllOpnds.push_back(&Opnd0_1); + if (Opnd1_ExpNum == 2) + AllOpnds.push_back(&Opnd1_1); + + // Compute instruction quota. We should save at least one instruction. + unsigned InstQuota = 0; + + Value *V0 = I->getOperand(0); + Value *V1 = I->getOperand(1); + InstQuota = ((!isa<Constant>(V0) && V0->hasOneUse()) && + (!isa<Constant>(V1) && V1->hasOneUse())) ? 2 : 1; + + if (Value *R = simplifyFAdd(AllOpnds, InstQuota)) + return R; + } + + if (OpndNum != 2) { + // The input instruction is : "I=0.0 +/- V". If the "V" were able to be + // splitted into two addends, say "V = X - Y", the instruction would have + // been optimized into "I = Y - X" in the previous steps. + // + const FAddendCoef &CE = Opnd0.getCoef(); + return CE.isOne() ? Opnd0.getSymVal() : nullptr; + } + + // step 4: Try to optimize Opnd0 + Opnd1_0 [+ Opnd1_1] + if (Opnd1_ExpNum) { + AddendVect AllOpnds; + AllOpnds.push_back(&Opnd0); + AllOpnds.push_back(&Opnd1_0); + if (Opnd1_ExpNum == 2) + AllOpnds.push_back(&Opnd1_1); + + if (Value *R = simplifyFAdd(AllOpnds, 1)) + return R; + } + + // step 5: Try to optimize Opnd1 + Opnd0_0 [+ Opnd0_1] + if (Opnd0_ExpNum) { + AddendVect AllOpnds; + AllOpnds.push_back(&Opnd1); + AllOpnds.push_back(&Opnd0_0); + if (Opnd0_ExpNum == 2) + AllOpnds.push_back(&Opnd0_1); + + if (Value *R = simplifyFAdd(AllOpnds, 1)) + return R; + } + + return nullptr; +} + +Value *FAddCombine::simplifyFAdd(AddendVect& Addends, unsigned InstrQuota) { + unsigned AddendNum = Addends.size(); + assert(AddendNum <= 4 && "Too many addends"); + + // For saving intermediate results; + unsigned NextTmpIdx = 0; + FAddend TmpResult[3]; + + // Simplified addends are placed <SimpVect>. 
+ AddendVect SimpVect; + + // The outer loop works on one symbolic-value at a time. Suppose the input + // addends are : <a1, x>, <b1, y>, <a2, x>, <c1, z>, <b2, y>, ... + // The symbolic-values will be processed in this order: x, y, z. + for (unsigned SymIdx = 0; SymIdx < AddendNum; SymIdx++) { + + const FAddend *ThisAddend = Addends[SymIdx]; + if (!ThisAddend) { + // This addend was processed before. + continue; + } + + Value *Val = ThisAddend->getSymVal(); + + // If the resulting expr has constant-addend, this constant-addend is + // desirable to reside at the top of the resulting expression tree. Placing + // constant close to super-expr(s) will potentially reveal some + // optimization opportunities in super-expr(s). Here we do not implement + // this logic intentionally and rely on SimplifyAssociativeOrCommutative + // call later. + + unsigned StartIdx = SimpVect.size(); + SimpVect.push_back(ThisAddend); + + // The inner loop collects addends sharing same symbolic-value, and these + // addends will be later on folded into a single addend. Following above + // example, if the symbolic value "y" is being processed, the inner loop + // will collect two addends "<b1,y>" and "<b2,Y>". These two addends will + // be later on folded into "<b1+b2, y>". + for (unsigned SameSymIdx = SymIdx + 1; + SameSymIdx < AddendNum; SameSymIdx++) { + const FAddend *T = Addends[SameSymIdx]; + if (T && T->getSymVal() == Val) { + // Set null such that next iteration of the outer loop will not process + // this addend again. + Addends[SameSymIdx] = nullptr; + SimpVect.push_back(T); + } + } + + // If multiple addends share same symbolic value, fold them together. + if (StartIdx + 1 != SimpVect.size()) { + FAddend &R = TmpResult[NextTmpIdx ++]; + R = *SimpVect[StartIdx]; + for (unsigned Idx = StartIdx + 1; Idx < SimpVect.size(); Idx++) + R += *SimpVect[Idx]; + + // Pop all addends being folded and push the resulting folded addend. 
+ SimpVect.resize(StartIdx); + if (!R.isZero()) { + SimpVect.push_back(&R); + } + } + } + + assert((NextTmpIdx <= array_lengthof(TmpResult) + 1) && + "out-of-bound access"); + + Value *Result; + if (!SimpVect.empty()) + Result = createNaryFAdd(SimpVect, InstrQuota); + else { + // The addition is folded to 0.0. + Result = ConstantFP::get(Instr->getType(), 0.0); + } + + return Result; +} + +Value *FAddCombine::createNaryFAdd + (const AddendVect &Opnds, unsigned InstrQuota) { + assert(!Opnds.empty() && "Expect at least one addend"); + + // Step 1: Check if the # of instructions needed exceeds the quota. + + unsigned InstrNeeded = calcInstrNumber(Opnds); + if (InstrNeeded > InstrQuota) + return nullptr; + + initCreateInstNum(); + + // step 2: Emit the N-ary addition. + // Note that at most three instructions are involved in Fadd-InstCombine: the + // addition in question, and at most two neighboring instructions. + // The resulting optimized addition should have at least one less instruction + // than the original addition expression tree. This implies that the resulting + // N-ary addition has at most two instructions, and we don't need to worry + // about tree-height when constructing the N-ary addition. + + Value *LastVal = nullptr; + bool LastValNeedNeg = false; + + // Iterate the addends, creating fadd/fsub using adjacent two addends. 
+ for (const FAddend *Opnd : Opnds) { + bool NeedNeg; + Value *V = createAddendVal(*Opnd, NeedNeg); + if (!LastVal) { + LastVal = V; + LastValNeedNeg = NeedNeg; + continue; + } + + if (LastValNeedNeg == NeedNeg) { + LastVal = createFAdd(LastVal, V); + continue; + } + + if (LastValNeedNeg) + LastVal = createFSub(V, LastVal); + else + LastVal = createFSub(LastVal, V); + + LastValNeedNeg = false; + } + + if (LastValNeedNeg) { + LastVal = createFNeg(LastVal); + } + +#ifndef NDEBUG + assert(CreateInstrNum == InstrNeeded && + "Inconsistent in instruction numbers"); +#endif + + return LastVal; +} + +Value *FAddCombine::createFSub(Value *Opnd0, Value *Opnd1) { + Value *V = Builder.CreateFSub(Opnd0, Opnd1); + if (Instruction *I = dyn_cast<Instruction>(V)) + createInstPostProc(I); + return V; +} + +Value *FAddCombine::createFNeg(Value *V) { + Value *NewV = Builder.CreateFNeg(V); + if (Instruction *I = dyn_cast<Instruction>(NewV)) + createInstPostProc(I, true); // fneg's don't receive instruction numbers. + return NewV; +} + +Value *FAddCombine::createFAdd(Value *Opnd0, Value *Opnd1) { + Value *V = Builder.CreateFAdd(Opnd0, Opnd1); + if (Instruction *I = dyn_cast<Instruction>(V)) + createInstPostProc(I); + return V; +} + +Value *FAddCombine::createFMul(Value *Opnd0, Value *Opnd1) { + Value *V = Builder.CreateFMul(Opnd0, Opnd1); + if (Instruction *I = dyn_cast<Instruction>(V)) + createInstPostProc(I); + return V; +} + +void FAddCombine::createInstPostProc(Instruction *NewInstr, bool NoNumber) { + NewInstr->setDebugLoc(Instr->getDebugLoc()); + + // Keep track of the number of instruction created. + if (!NoNumber) + incCreateInstNum(); + + // Propagate fast-math flags + NewInstr->setFastMathFlags(Instr->getFastMathFlags()); +} + +// Return the number of instruction needed to emit the N-ary addition. +// NOTE: Keep this function in sync with createAddendVal(). 
+unsigned FAddCombine::calcInstrNumber(const AddendVect &Opnds) { + unsigned OpndNum = Opnds.size(); + unsigned InstrNeeded = OpndNum - 1; + + // Adjust the number of instructions needed to emit the N-ary add. + for (const FAddend *Opnd : Opnds) { + if (Opnd->isConstant()) + continue; + + // The constant check above is really for a few special constant + // coefficients. + if (isa<UndefValue>(Opnd->getSymVal())) + continue; + + const FAddendCoef &CE = Opnd->getCoef(); + // Let the addend be "c * x". If "c == +/-1", the value of the addend + // is immediately available; otherwise, it needs exactly one instruction + // to evaluate the value. + if (!CE.isMinusOne() && !CE.isOne()) + InstrNeeded++; + } + return InstrNeeded; +} + +// Input Addend Value NeedNeg(output) +// ================================================================ +// Constant C C false +// <+/-1, V> V coefficient is -1 +// <2/-2, V> "fadd V, V" coefficient is -2 +// <C, V> "fmul V, C" false +// +// NOTE: Keep this function in sync with FAddCombine::calcInstrNumber. +Value *FAddCombine::createAddendVal(const FAddend &Opnd, bool &NeedNeg) { + const FAddendCoef &Coeff = Opnd.getCoef(); + + if (Opnd.isConstant()) { + NeedNeg = false; + return Coeff.getValue(Instr->getType()); + } + + Value *OpndVal = Opnd.getSymVal(); + + if (Coeff.isMinusOne() || Coeff.isOne()) { + NeedNeg = Coeff.isMinusOne(); + return OpndVal; + } + + if (Coeff.isTwo() || Coeff.isMinusTwo()) { + NeedNeg = Coeff.isMinusTwo(); + return createFAdd(OpndVal, OpndVal); + } + + NeedNeg = false; + return createFMul(OpndVal, Coeff.getValue(Instr->getType())); +} + +// Checks if any operand is negative and we can convert add to sub. 
+// This function checks for following negative patterns +// ADD(XOR(OR(Z, NOT(C)), C)), 1) == NEG(AND(Z, C)) +// ADD(XOR(AND(Z, C), C), 1) == NEG(OR(Z, ~C)) +// XOR(AND(Z, C), (C + 1)) == NEG(OR(Z, ~C)) if C is even +static Value *checkForNegativeOperand(BinaryOperator &I, + InstCombiner::BuilderTy &Builder) { + Value *LHS = I.getOperand(0), *RHS = I.getOperand(1); + + // This function creates 2 instructions to replace ADD, we need at least one + // of LHS or RHS to have one use to ensure benefit in transform. + if (!LHS->hasOneUse() && !RHS->hasOneUse()) + return nullptr; + + Value *X = nullptr, *Y = nullptr, *Z = nullptr; + const APInt *C1 = nullptr, *C2 = nullptr; + + // if ONE is on other side, swap + if (match(RHS, m_Add(m_Value(X), m_One()))) + std::swap(LHS, RHS); + + if (match(LHS, m_Add(m_Value(X), m_One()))) { + // if XOR on other side, swap + if (match(RHS, m_Xor(m_Value(Y), m_APInt(C1)))) + std::swap(X, RHS); + + if (match(X, m_Xor(m_Value(Y), m_APInt(C1)))) { + // X = XOR(Y, C1), Y = OR(Z, C2), C2 = NOT(C1) ==> X == NOT(AND(Z, C1)) + // ADD(ADD(X, 1), RHS) == ADD(X, ADD(RHS, 1)) == SUB(RHS, AND(Z, C1)) + if (match(Y, m_Or(m_Value(Z), m_APInt(C2))) && (*C2 == ~(*C1))) { + Value *NewAnd = Builder.CreateAnd(Z, *C1); + return Builder.CreateSub(RHS, NewAnd, "sub"); + } else if (match(Y, m_And(m_Value(Z), m_APInt(C2))) && (*C1 == *C2)) { + // X = XOR(Y, C1), Y = AND(Z, C2), C2 == C1 ==> X == NOT(OR(Z, ~C1)) + // ADD(ADD(X, 1), RHS) == ADD(X, ADD(RHS, 1)) == SUB(RHS, OR(Z, ~C1)) + Value *NewOr = Builder.CreateOr(Z, ~(*C1)); + return Builder.CreateSub(RHS, NewOr, "sub"); + } + } + } + + // Restore LHS and RHS + LHS = I.getOperand(0); + RHS = I.getOperand(1); + + // if XOR is on other side, swap + if (match(RHS, m_Xor(m_Value(Y), m_APInt(C1)))) + std::swap(LHS, RHS); + + // C2 is ODD + // LHS = XOR(Y, C1), Y = AND(Z, C2), C1 == (C2 + 1) => LHS == NEG(OR(Z, ~C2)) + // ADD(LHS, RHS) == SUB(RHS, OR(Z, ~C2)) + if (match(LHS, m_Xor(m_Value(Y), m_APInt(C1)))) + if 
(C1->countTrailingZeros() == 0) + if (match(Y, m_And(m_Value(Z), m_APInt(C2))) && *C1 == (*C2 + 1)) { + Value *NewOr = Builder.CreateOr(Z, ~(*C2)); + return Builder.CreateSub(RHS, NewOr, "sub"); + } + return nullptr; +} + +/// Wrapping flags may allow combining constants separated by an extend. +static Instruction *foldNoWrapAdd(BinaryOperator &Add, + InstCombiner::BuilderTy &Builder) { + Value *Op0 = Add.getOperand(0), *Op1 = Add.getOperand(1); + Type *Ty = Add.getType(); + Constant *Op1C; + if (!match(Op1, m_Constant(Op1C))) + return nullptr; + + // Try this match first because it results in an add in the narrow type. + // (zext (X +nuw C2)) + C1 --> zext (X + (C2 + trunc(C1))) + Value *X; + const APInt *C1, *C2; + if (match(Op1, m_APInt(C1)) && + match(Op0, m_OneUse(m_ZExt(m_NUWAdd(m_Value(X), m_APInt(C2))))) && + C1->isNegative() && C1->sge(-C2->sext(C1->getBitWidth()))) { + Constant *NewC = + ConstantInt::get(X->getType(), *C2 + C1->trunc(C2->getBitWidth())); + return new ZExtInst(Builder.CreateNUWAdd(X, NewC), Ty); + } + + // More general combining of constants in the wide type. 
+ // (sext (X +nsw NarrowC)) + C --> (sext X) + (sext(NarrowC) + C) + Constant *NarrowC; + if (match(Op0, m_OneUse(m_SExt(m_NSWAdd(m_Value(X), m_Constant(NarrowC)))))) { + Constant *WideC = ConstantExpr::getSExt(NarrowC, Ty); + Constant *NewC = ConstantExpr::getAdd(WideC, Op1C); + Value *WideX = Builder.CreateSExt(X, Ty); + return BinaryOperator::CreateAdd(WideX, NewC); + } + // (zext (X +nuw NarrowC)) + C --> (zext X) + (zext(NarrowC) + C) + if (match(Op0, m_OneUse(m_ZExt(m_NUWAdd(m_Value(X), m_Constant(NarrowC)))))) { + Constant *WideC = ConstantExpr::getZExt(NarrowC, Ty); + Constant *NewC = ConstantExpr::getAdd(WideC, Op1C); + Value *WideX = Builder.CreateZExt(X, Ty); + return BinaryOperator::CreateAdd(WideX, NewC); + } + + return nullptr; +} + +Instruction *InstCombinerImpl::foldAddWithConstant(BinaryOperator &Add) { + Value *Op0 = Add.getOperand(0), *Op1 = Add.getOperand(1); + Constant *Op1C; + if (!match(Op1, m_ImmConstant(Op1C))) + return nullptr; + + if (Instruction *NV = foldBinOpIntoSelectOrPhi(Add)) + return NV; + + Value *X; + Constant *Op00C; + + // add (sub C1, X), C2 --> sub (add C1, C2), X + if (match(Op0, m_Sub(m_Constant(Op00C), m_Value(X)))) + return BinaryOperator::CreateSub(ConstantExpr::getAdd(Op00C, Op1C), X); + + Value *Y; + + // add (sub X, Y), -1 --> add (not Y), X + if (match(Op0, m_OneUse(m_Sub(m_Value(X), m_Value(Y)))) && + match(Op1, m_AllOnes())) + return BinaryOperator::CreateAdd(Builder.CreateNot(Y), X); + + // zext(bool) + C -> bool ? C + 1 : C + if (match(Op0, m_ZExt(m_Value(X))) && + X->getType()->getScalarSizeInBits() == 1) + return SelectInst::Create(X, InstCombiner::AddOne(Op1C), Op1); + // sext(bool) + C -> bool ? 
C - 1 : C + if (match(Op0, m_SExt(m_Value(X))) && + X->getType()->getScalarSizeInBits() == 1) + return SelectInst::Create(X, InstCombiner::SubOne(Op1C), Op1); + + // ~X + C --> (C-1) - X + if (match(Op0, m_Not(m_Value(X)))) + return BinaryOperator::CreateSub(InstCombiner::SubOne(Op1C), X); + + const APInt *C; + if (!match(Op1, m_APInt(C))) + return nullptr; + + // (X | Op01C) + Op1C --> X + (Op01C + Op1C) iff the `or` is actually an `add` + Constant *Op01C; + if (match(Op0, m_Or(m_Value(X), m_ImmConstant(Op01C))) && + haveNoCommonBitsSet(X, Op01C, DL, &AC, &Add, &DT)) + return BinaryOperator::CreateAdd(X, ConstantExpr::getAdd(Op01C, Op1C)); + + // (X | C2) + C --> (X | C2) ^ C2 iff (C2 == -C) + const APInt *C2; + if (match(Op0, m_Or(m_Value(), m_APInt(C2))) && *C2 == -*C) + return BinaryOperator::CreateXor(Op0, ConstantInt::get(Add.getType(), *C2)); + + if (C->isSignMask()) { + // If wrapping is not allowed, then the addition must set the sign bit: + // X + (signmask) --> X | signmask + if (Add.hasNoSignedWrap() || Add.hasNoUnsignedWrap()) + return BinaryOperator::CreateOr(Op0, Op1); + + // If wrapping is allowed, then the addition flips the sign bit of LHS: + // X + (signmask) --> X ^ signmask + return BinaryOperator::CreateXor(Op0, Op1); + } + + // Is this add the last step in a convoluted sext? 
+ // add(zext(xor i16 X, -32768), -32768) --> sext X + Type *Ty = Add.getType(); + if (match(Op0, m_ZExt(m_Xor(m_Value(X), m_APInt(C2)))) && + C2->isMinSignedValue() && C2->sext(Ty->getScalarSizeInBits()) == *C) + return CastInst::Create(Instruction::SExt, X, Ty); + + if (match(Op0, m_Xor(m_Value(X), m_APInt(C2)))) { + // (X ^ signmask) + C --> (X + (signmask ^ C)) + if (C2->isSignMask()) + return BinaryOperator::CreateAdd(X, ConstantInt::get(Ty, *C2 ^ *C)); + + // If X has no high-bits set above an xor mask: + // add (xor X, LowMaskC), C --> sub (LowMaskC + C), X + if (C2->isMask()) { + KnownBits LHSKnown = computeKnownBits(X, 0, &Add); + if ((*C2 | LHSKnown.Zero).isAllOnes()) + return BinaryOperator::CreateSub(ConstantInt::get(Ty, *C2 + *C), X); + } + + // Look for a math+logic pattern that corresponds to sext-in-register of a + // value with cleared high bits. Convert that into a pair of shifts: + // add (xor X, 0x80), 0xF..F80 --> (X << ShAmtC) >>s ShAmtC + // add (xor X, 0xF..F80), 0x80 --> (X << ShAmtC) >>s ShAmtC + if (Op0->hasOneUse() && *C2 == -(*C)) { + unsigned BitWidth = Ty->getScalarSizeInBits(); + unsigned ShAmt = 0; + if (C->isPowerOf2()) + ShAmt = BitWidth - C->logBase2() - 1; + else if (C2->isPowerOf2()) + ShAmt = BitWidth - C2->logBase2() - 1; + if (ShAmt && MaskedValueIsZero(X, APInt::getHighBitsSet(BitWidth, ShAmt), + 0, &Add)) { + Constant *ShAmtC = ConstantInt::get(Ty, ShAmt); + Value *NewShl = Builder.CreateShl(X, ShAmtC, "sext"); + return BinaryOperator::CreateAShr(NewShl, ShAmtC); + } + } + } + + if (C->isOne() && Op0->hasOneUse()) { + // add (sext i1 X), 1 --> zext (not X) + // TODO: The smallest IR representation is (select X, 0, 1), and that would + // not require the one-use check. But we need to remove a transform in + // visitSelect and make sure that IR value tracking for select is equal or + // better than for these ops. 
+ if (match(Op0, m_SExt(m_Value(X))) && + X->getType()->getScalarSizeInBits() == 1) + return new ZExtInst(Builder.CreateNot(X), Ty); + + // Shifts and add used to flip and mask off the low bit: + // add (ashr (shl i32 X, 31), 31), 1 --> and (not X), 1 + const APInt *C3; + if (match(Op0, m_AShr(m_Shl(m_Value(X), m_APInt(C2)), m_APInt(C3))) && + C2 == C3 && *C2 == Ty->getScalarSizeInBits() - 1) { + Value *NotX = Builder.CreateNot(X); + return BinaryOperator::CreateAnd(NotX, ConstantInt::get(Ty, 1)); + } + } + + // If all bits affected by the add are included in a high-bit-mask, do the + // add before the mask op: + // (X & 0xFF00) + xx00 --> (X + xx00) & 0xFF00 + if (match(Op0, m_OneUse(m_And(m_Value(X), m_APInt(C2)))) && + C2->isNegative() && C2->isShiftedMask() && *C == (*C & *C2)) { + Value *NewAdd = Builder.CreateAdd(X, ConstantInt::get(Ty, *C)); + return BinaryOperator::CreateAnd(NewAdd, ConstantInt::get(Ty, *C2)); + } + + return nullptr; +} + +// Matches multiplication expression Op * C where C is a constant. Returns the +// constant value in C and the other operand in Op. Returns true if such a +// match is found. +static bool MatchMul(Value *E, Value *&Op, APInt &C) { + const APInt *AI; + if (match(E, m_Mul(m_Value(Op), m_APInt(AI)))) { + C = *AI; + return true; + } + if (match(E, m_Shl(m_Value(Op), m_APInt(AI)))) { + C = APInt(AI->getBitWidth(), 1); + C <<= *AI; + return true; + } + return false; +} + +// Matches remainder expression Op % C where C is a constant. Returns the +// constant value in C and the other operand in Op. Returns the signedness of +// the remainder operation in IsSigned. Returns true if such a match is +// found. 
+static bool MatchRem(Value *E, Value *&Op, APInt &C, bool &IsSigned) { + const APInt *AI; + IsSigned = false; + if (match(E, m_SRem(m_Value(Op), m_APInt(AI)))) { + IsSigned = true; + C = *AI; + return true; + } + if (match(E, m_URem(m_Value(Op), m_APInt(AI)))) { + C = *AI; + return true; + } + if (match(E, m_And(m_Value(Op), m_APInt(AI))) && (*AI + 1).isPowerOf2()) { + C = *AI + 1; + return true; + } + return false; +} + +// Matches division expression Op / C with the given signedness as indicated +// by IsSigned, where C is a constant. Returns the constant value in C and the +// other operand in Op. Returns true if such a match is found. +static bool MatchDiv(Value *E, Value *&Op, APInt &C, bool IsSigned) { + const APInt *AI; + if (IsSigned && match(E, m_SDiv(m_Value(Op), m_APInt(AI)))) { + C = *AI; + return true; + } + if (!IsSigned) { + if (match(E, m_UDiv(m_Value(Op), m_APInt(AI)))) { + C = *AI; + return true; + } + if (match(E, m_LShr(m_Value(Op), m_APInt(AI)))) { + C = APInt(AI->getBitWidth(), 1); + C <<= *AI; + return true; + } + } + return false; +} + +// Returns whether C0 * C1 with the given signedness overflows. +static bool MulWillOverflow(APInt &C0, APInt &C1, bool IsSigned) { + bool overflow; + if (IsSigned) + (void)C0.smul_ov(C1, overflow); + else + (void)C0.umul_ov(C1, overflow); + return overflow; +} + +// Simplifies X % C0 + (( X / C0 ) % C1) * C0 to X % (C0 * C1), where (C0 * C1) +// does not overflow. 
+Value *InstCombinerImpl::SimplifyAddWithRemainder(BinaryOperator &I) { + Value *LHS = I.getOperand(0), *RHS = I.getOperand(1); + Value *X, *MulOpV; + APInt C0, MulOpC; + bool IsSigned; + // Match I = X % C0 + MulOpV * C0 + if (((MatchRem(LHS, X, C0, IsSigned) && MatchMul(RHS, MulOpV, MulOpC)) || + (MatchRem(RHS, X, C0, IsSigned) && MatchMul(LHS, MulOpV, MulOpC))) && + C0 == MulOpC) { + Value *RemOpV; + APInt C1; + bool Rem2IsSigned; + // Match MulOpC = RemOpV % C1 + if (MatchRem(MulOpV, RemOpV, C1, Rem2IsSigned) && + IsSigned == Rem2IsSigned) { + Value *DivOpV; + APInt DivOpC; + // Match RemOpV = X / C0 + if (MatchDiv(RemOpV, DivOpV, DivOpC, IsSigned) && X == DivOpV && + C0 == DivOpC && !MulWillOverflow(C0, C1, IsSigned)) { + Value *NewDivisor = ConstantInt::get(X->getType(), C0 * C1); + return IsSigned ? Builder.CreateSRem(X, NewDivisor, "srem") + : Builder.CreateURem(X, NewDivisor, "urem"); + } + } + } + + return nullptr; +} + +/// Fold +/// (1 << NBits) - 1 +/// Into: +/// ~(-(1 << NBits)) +/// Because a 'not' is better for bit-tracking analysis and other transforms +/// than an 'add'. The new shl is always nsw, and is nuw if old `and` was. +static Instruction *canonicalizeLowbitMask(BinaryOperator &I, + InstCombiner::BuilderTy &Builder) { + Value *NBits; + if (!match(&I, m_Add(m_OneUse(m_Shl(m_One(), m_Value(NBits))), m_AllOnes()))) + return nullptr; + + Constant *MinusOne = Constant::getAllOnesValue(NBits->getType()); + Value *NotMask = Builder.CreateShl(MinusOne, NBits, "notmask"); + // Be wary of constant folding. + if (auto *BOp = dyn_cast<BinaryOperator>(NotMask)) { + // Always NSW. But NUW propagates from `add`. 
+ BOp->setHasNoSignedWrap(); + BOp->setHasNoUnsignedWrap(I.hasNoUnsignedWrap()); + } + + return BinaryOperator::CreateNot(NotMask, I.getName()); +} + +static Instruction *foldToUnsignedSaturatedAdd(BinaryOperator &I) { + assert(I.getOpcode() == Instruction::Add && "Expecting add instruction"); + Type *Ty = I.getType(); + auto getUAddSat = [&]() { + return Intrinsic::getDeclaration(I.getModule(), Intrinsic::uadd_sat, Ty); + }; + + // add (umin X, ~Y), Y --> uaddsat X, Y + Value *X, *Y; + if (match(&I, m_c_Add(m_c_UMin(m_Value(X), m_Not(m_Value(Y))), + m_Deferred(Y)))) + return CallInst::Create(getUAddSat(), { X, Y }); + + // add (umin X, ~C), C --> uaddsat X, C + const APInt *C, *NotC; + if (match(&I, m_Add(m_UMin(m_Value(X), m_APInt(NotC)), m_APInt(C))) && + *C == ~*NotC) + return CallInst::Create(getUAddSat(), { X, ConstantInt::get(Ty, *C) }); + + return nullptr; +} + +Instruction *InstCombinerImpl:: + canonicalizeCondSignextOfHighBitExtractToSignextHighBitExtract( + BinaryOperator &I) { + assert((I.getOpcode() == Instruction::Add || + I.getOpcode() == Instruction::Or || + I.getOpcode() == Instruction::Sub) && + "Expecting add/or/sub instruction"); + + // We have a subtraction/addition between a (potentially truncated) *logical* + // right-shift of X and a "select". + Value *X, *Select; + Instruction *LowBitsToSkip, *Extract; + if (!match(&I, m_c_BinOp(m_TruncOrSelf(m_CombineAnd( + m_LShr(m_Value(X), m_Instruction(LowBitsToSkip)), + m_Instruction(Extract))), + m_Value(Select)))) + return nullptr; + + // `add`/`or` is commutative; but for `sub`, "select" *must* be on RHS. + if (I.getOpcode() == Instruction::Sub && I.getOperand(1) != Select) + return nullptr; + + Type *XTy = X->getType(); + bool HadTrunc = I.getType() != XTy; + + // If there was a truncation of extracted value, then we'll need to produce + // one extra instruction, so we need to ensure one instruction will go away. 
+ if (HadTrunc && !match(&I, m_c_BinOp(m_OneUse(m_Value()), m_Value()))) + return nullptr; + + // Extraction should extract high NBits bits, with shift amount calculated as: + // low bits to skip = shift bitwidth - high bits to extract + // The shift amount itself may be extended, and we need to look past zero-ext + // when matching NBits, that will matter for matching later. + Constant *C; + Value *NBits; + if (!match( + LowBitsToSkip, + m_ZExtOrSelf(m_Sub(m_Constant(C), m_ZExtOrSelf(m_Value(NBits))))) || + !match(C, m_SpecificInt_ICMP(ICmpInst::Predicate::ICMP_EQ, + APInt(C->getType()->getScalarSizeInBits(), + X->getType()->getScalarSizeInBits())))) + return nullptr; + + // Sign-extending value can be zero-extended if we `sub`tract it, + // or sign-extended otherwise. + auto SkipExtInMagic = [&I](Value *&V) { + if (I.getOpcode() == Instruction::Sub) + match(V, m_ZExtOrSelf(m_Value(V))); + else + match(V, m_SExtOrSelf(m_Value(V))); + }; + + // Now, finally validate the sign-extending magic. + // `select` itself may be appropriately extended, look past that. + SkipExtInMagic(Select); + + ICmpInst::Predicate Pred; + const APInt *Thr; + Value *SignExtendingValue, *Zero; + bool ShouldSignext; + // It must be a select between two values we will later establish to be a + // sign-extending value and a zero constant. The condition guarding the + // sign-extension must be based on a sign bit of the same X we had in `lshr`. + if (!match(Select, m_Select(m_ICmp(Pred, m_Specific(X), m_APInt(Thr)), + m_Value(SignExtendingValue), m_Value(Zero))) || + !isSignBitCheck(Pred, *Thr, ShouldSignext)) + return nullptr; + + // icmp-select pair is commutative. + if (!ShouldSignext) + std::swap(SignExtendingValue, Zero); + + // If we should not perform sign-extension then we must add/or/subtract zero. + if (!match(Zero, m_Zero())) + return nullptr; + // Otherwise, it should be some constant, left-shifted by the same NBits we + // had in `lshr`. 
Said left-shift can also be appropriately extended. + // Again, we must look past zero-ext when looking for NBits. + SkipExtInMagic(SignExtendingValue); + Constant *SignExtendingValueBaseConstant; + if (!match(SignExtendingValue, + m_Shl(m_Constant(SignExtendingValueBaseConstant), + m_ZExtOrSelf(m_Specific(NBits))))) + return nullptr; + // If we `sub`, then the constant should be one, else it should be all-ones. + if (I.getOpcode() == Instruction::Sub + ? !match(SignExtendingValueBaseConstant, m_One()) + : !match(SignExtendingValueBaseConstant, m_AllOnes())) + return nullptr; + + auto *NewAShr = BinaryOperator::CreateAShr(X, LowBitsToSkip, + Extract->getName() + ".sext"); + NewAShr->copyIRFlags(Extract); // Preserve `exact`-ness. + if (!HadTrunc) + return NewAShr; + + Builder.Insert(NewAShr); + return TruncInst::CreateTruncOrBitCast(NewAShr, I.getType()); +} + +/// This is a specialization of a more general transform from +/// SimplifyUsingDistributiveLaws. If that code can be made to work optimally +/// for multi-use cases or propagating nsw/nuw, then we would not need this. +static Instruction *factorizeMathWithShlOps(BinaryOperator &I, + InstCombiner::BuilderTy &Builder) { + // TODO: Also handle mul by doubling the shift amount? + assert((I.getOpcode() == Instruction::Add || + I.getOpcode() == Instruction::Sub) && + "Expected add/sub"); + auto *Op0 = dyn_cast<BinaryOperator>(I.getOperand(0)); + auto *Op1 = dyn_cast<BinaryOperator>(I.getOperand(1)); + if (!Op0 || !Op1 || !(Op0->hasOneUse() || Op1->hasOneUse())) + return nullptr; + + Value *X, *Y, *ShAmt; + if (!match(Op0, m_Shl(m_Value(X), m_Value(ShAmt))) || + !match(Op1, m_Shl(m_Value(Y), m_Specific(ShAmt)))) + return nullptr; + + // No-wrap propagates only when all ops have no-wrap. 
+ bool HasNSW = I.hasNoSignedWrap() && Op0->hasNoSignedWrap() && + Op1->hasNoSignedWrap(); + bool HasNUW = I.hasNoUnsignedWrap() && Op0->hasNoUnsignedWrap() && + Op1->hasNoUnsignedWrap(); + + // add/sub (X << ShAmt), (Y << ShAmt) --> (add/sub X, Y) << ShAmt + Value *NewMath = Builder.CreateBinOp(I.getOpcode(), X, Y); + if (auto *NewI = dyn_cast<BinaryOperator>(NewMath)) { + NewI->setHasNoSignedWrap(HasNSW); + NewI->setHasNoUnsignedWrap(HasNUW); + } + auto *NewShl = BinaryOperator::CreateShl(NewMath, ShAmt); + NewShl->setHasNoSignedWrap(HasNSW); + NewShl->setHasNoUnsignedWrap(HasNUW); + return NewShl; +} + +Instruction *InstCombinerImpl::visitAdd(BinaryOperator &I) { + if (Value *V = simplifyAddInst(I.getOperand(0), I.getOperand(1), + I.hasNoSignedWrap(), I.hasNoUnsignedWrap(), + SQ.getWithInstruction(&I))) + return replaceInstUsesWith(I, V); + + if (SimplifyAssociativeOrCommutative(I)) + return &I; + + if (Instruction *X = foldVectorBinop(I)) + return X; + + if (Instruction *Phi = foldBinopWithPhiOperands(I)) + return Phi; + + // (A*B)+(A*C) -> A*(B+C) etc + if (Value *V = SimplifyUsingDistributiveLaws(I)) + return replaceInstUsesWith(I, V); + + if (Instruction *R = factorizeMathWithShlOps(I, Builder)) + return R; + + if (Instruction *X = foldAddWithConstant(I)) + return X; + + if (Instruction *X = foldNoWrapAdd(I, Builder)) + return X; + + Value *LHS = I.getOperand(0), *RHS = I.getOperand(1); + Type *Ty = I.getType(); + if (Ty->isIntOrIntVectorTy(1)) + return BinaryOperator::CreateXor(LHS, RHS); + + // X + X --> X << 1 + if (LHS == RHS) { + auto *Shl = BinaryOperator::CreateShl(LHS, ConstantInt::get(Ty, 1)); + Shl->setHasNoSignedWrap(I.hasNoSignedWrap()); + Shl->setHasNoUnsignedWrap(I.hasNoUnsignedWrap()); + return Shl; + } + + Value *A, *B; + if (match(LHS, m_Neg(m_Value(A)))) { + // -A + -B --> -(A + B) + if (match(RHS, m_Neg(m_Value(B)))) + return BinaryOperator::CreateNeg(Builder.CreateAdd(A, B)); + + // -A + B --> B - A + return BinaryOperator::CreateSub(RHS, 
A); + } + + // A + -B --> A - B + if (match(RHS, m_Neg(m_Value(B)))) + return BinaryOperator::CreateSub(LHS, B); + + if (Value *V = checkForNegativeOperand(I, Builder)) + return replaceInstUsesWith(I, V); + + // (A + 1) + ~B --> A - B + // ~B + (A + 1) --> A - B + // (~B + A) + 1 --> A - B + // (A + ~B) + 1 --> A - B + if (match(&I, m_c_BinOp(m_Add(m_Value(A), m_One()), m_Not(m_Value(B)))) || + match(&I, m_BinOp(m_c_Add(m_Not(m_Value(B)), m_Value(A)), m_One()))) + return BinaryOperator::CreateSub(A, B); + + // (A + RHS) + RHS --> A + (RHS << 1) + if (match(LHS, m_OneUse(m_c_Add(m_Value(A), m_Specific(RHS))))) + return BinaryOperator::CreateAdd(A, Builder.CreateShl(RHS, 1, "reass.add")); + + // LHS + (A + LHS) --> A + (LHS << 1) + if (match(RHS, m_OneUse(m_c_Add(m_Value(A), m_Specific(LHS))))) + return BinaryOperator::CreateAdd(A, Builder.CreateShl(LHS, 1, "reass.add")); + + { + // (A + C1) + (C2 - B) --> (A - B) + (C1 + C2) + Constant *C1, *C2; + if (match(&I, m_c_Add(m_Add(m_Value(A), m_ImmConstant(C1)), + m_Sub(m_ImmConstant(C2), m_Value(B)))) && + (LHS->hasOneUse() || RHS->hasOneUse())) { + Value *Sub = Builder.CreateSub(A, B); + return BinaryOperator::CreateAdd(Sub, ConstantExpr::getAdd(C1, C2)); + } + } + + // X % C0 + (( X / C0 ) % C1) * C0 => X % (C0 * C1) + if (Value *V = SimplifyAddWithRemainder(I)) return replaceInstUsesWith(I, V); + + // ((X s/ C1) << C2) + X => X s% -C1 where -C1 is 1 << C2 + const APInt *C1, *C2; + if (match(LHS, m_Shl(m_SDiv(m_Specific(RHS), m_APInt(C1)), m_APInt(C2)))) { + APInt one(C2->getBitWidth(), 1); + APInt minusC1 = -(*C1); + if (minusC1 == (one << *C2)) { + Constant *NewRHS = ConstantInt::get(RHS->getType(), minusC1); + return BinaryOperator::CreateSRem(RHS, NewRHS); + } + } + + // (A & 2^C1) + A => A & (2^C1 - 1) iff bit C1 in A is a sign bit + if (match(&I, m_c_Add(m_And(m_Value(A), m_APInt(C1)), m_Deferred(A))) && + C1->isPowerOf2() && (ComputeNumSignBits(A) > C1->countLeadingZeros())) { + Constant *NewMask = 
ConstantInt::get(RHS->getType(), *C1 - 1); + return BinaryOperator::CreateAnd(A, NewMask); + } + + // A+B --> A|B iff A and B have no bits set in common. + if (haveNoCommonBitsSet(LHS, RHS, DL, &AC, &I, &DT)) + return BinaryOperator::CreateOr(LHS, RHS); + + // add (select X 0 (sub n A)) A --> select X A n + { + SelectInst *SI = dyn_cast<SelectInst>(LHS); + Value *A = RHS; + if (!SI) { + SI = dyn_cast<SelectInst>(RHS); + A = LHS; + } + if (SI && SI->hasOneUse()) { + Value *TV = SI->getTrueValue(); + Value *FV = SI->getFalseValue(); + Value *N; + + // Can we fold the add into the argument of the select? + // We check both true and false select arguments for a matching subtract. + if (match(FV, m_Zero()) && match(TV, m_Sub(m_Value(N), m_Specific(A)))) + // Fold the add into the true select value. + return SelectInst::Create(SI->getCondition(), N, A); + + if (match(TV, m_Zero()) && match(FV, m_Sub(m_Value(N), m_Specific(A)))) + // Fold the add into the false select value. + return SelectInst::Create(SI->getCondition(), A, N); + } + } + + if (Instruction *Ext = narrowMathIfNoOverflow(I)) + return Ext; + + // (add (xor A, B) (and A, B)) --> (or A, B) + // (add (and A, B) (xor A, B)) --> (or A, B) + if (match(&I, m_c_BinOp(m_Xor(m_Value(A), m_Value(B)), + m_c_And(m_Deferred(A), m_Deferred(B))))) + return BinaryOperator::CreateOr(A, B); + + // (add (or A, B) (and A, B)) --> (add A, B) + // (add (and A, B) (or A, B)) --> (add A, B) + if (match(&I, m_c_BinOp(m_Or(m_Value(A), m_Value(B)), + m_c_And(m_Deferred(A), m_Deferred(B))))) { + // Replacing operands in-place to preserve nuw/nsw flags. + replaceOperand(I, 0, A); + replaceOperand(I, 1, B); + return &I; + } + + // TODO(jingyue): Consider willNotOverflowSignedAdd and + // willNotOverflowUnsignedAdd to reduce the number of invocations of + // computeKnownBits. 
+ bool Changed = false; + if (!I.hasNoSignedWrap() && willNotOverflowSignedAdd(LHS, RHS, I)) { + Changed = true; + I.setHasNoSignedWrap(true); + } + if (!I.hasNoUnsignedWrap() && willNotOverflowUnsignedAdd(LHS, RHS, I)) { + Changed = true; + I.setHasNoUnsignedWrap(true); + } + + if (Instruction *V = canonicalizeLowbitMask(I, Builder)) + return V; + + if (Instruction *V = + canonicalizeCondSignextOfHighBitExtractToSignextHighBitExtract(I)) + return V; + + if (Instruction *SatAdd = foldToUnsignedSaturatedAdd(I)) + return SatAdd; + + // usub.sat(A, B) + B => umax(A, B) + if (match(&I, m_c_BinOp( + m_OneUse(m_Intrinsic<Intrinsic::usub_sat>(m_Value(A), m_Value(B))), + m_Deferred(B)))) { + return replaceInstUsesWith(I, + Builder.CreateIntrinsic(Intrinsic::umax, {I.getType()}, {A, B})); + } + + // ctpop(A) + ctpop(B) => ctpop(A | B) if A and B have no bits set in common. + if (match(LHS, m_OneUse(m_Intrinsic<Intrinsic::ctpop>(m_Value(A)))) && + match(RHS, m_OneUse(m_Intrinsic<Intrinsic::ctpop>(m_Value(B)))) && + haveNoCommonBitsSet(A, B, DL, &AC, &I, &DT)) + return replaceInstUsesWith( + I, Builder.CreateIntrinsic(Intrinsic::ctpop, {I.getType()}, + {Builder.CreateOr(A, B)})); + + return Changed ? &I : nullptr; +} + +/// Eliminate an op from a linear interpolation (lerp) pattern. +static Instruction *factorizeLerp(BinaryOperator &I, + InstCombiner::BuilderTy &Builder) { + Value *X, *Y, *Z; + if (!match(&I, m_c_FAdd(m_OneUse(m_c_FMul(m_Value(Y), + m_OneUse(m_FSub(m_FPOne(), + m_Value(Z))))), + m_OneUse(m_c_FMul(m_Value(X), m_Deferred(Z)))))) + return nullptr; + + // (Y * (1.0 - Z)) + (X * Z) --> Y + Z * (X - Y) [8 commuted variants] + Value *XY = Builder.CreateFSubFMF(X, Y, &I); + Value *MulZ = Builder.CreateFMulFMF(Z, XY, &I); + return BinaryOperator::CreateFAddFMF(Y, MulZ, &I); +} + +/// Factor a common operand out of fadd/fsub of fmul/fdiv. 
+static Instruction *factorizeFAddFSub(BinaryOperator &I, + InstCombiner::BuilderTy &Builder) { + assert((I.getOpcode() == Instruction::FAdd || + I.getOpcode() == Instruction::FSub) && "Expecting fadd/fsub"); + assert(I.hasAllowReassoc() && I.hasNoSignedZeros() && + "FP factorization requires FMF"); + + if (Instruction *Lerp = factorizeLerp(I, Builder)) + return Lerp; + + Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); + if (!Op0->hasOneUse() || !Op1->hasOneUse()) + return nullptr; + + Value *X, *Y, *Z; + bool IsFMul; + if ((match(Op0, m_FMul(m_Value(X), m_Value(Z))) && + match(Op1, m_c_FMul(m_Value(Y), m_Specific(Z)))) || + (match(Op0, m_FMul(m_Value(Z), m_Value(X))) && + match(Op1, m_c_FMul(m_Value(Y), m_Specific(Z))))) + IsFMul = true; + else if (match(Op0, m_FDiv(m_Value(X), m_Value(Z))) && + match(Op1, m_FDiv(m_Value(Y), m_Specific(Z)))) + IsFMul = false; + else + return nullptr; + + // (X * Z) + (Y * Z) --> (X + Y) * Z + // (X * Z) - (Y * Z) --> (X - Y) * Z + // (X / Z) + (Y / Z) --> (X + Y) / Z + // (X / Z) - (Y / Z) --> (X - Y) / Z + bool IsFAdd = I.getOpcode() == Instruction::FAdd; + Value *XY = IsFAdd ? Builder.CreateFAddFMF(X, Y, &I) + : Builder.CreateFSubFMF(X, Y, &I); + + // Bail out if we just created a denormal constant. + // TODO: This is copied from a previous implementation. Is it necessary? + const APFloat *C; + if (match(XY, m_APFloat(C)) && !C->isNormal()) + return nullptr; + + return IsFMul ? 
BinaryOperator::CreateFMulFMF(XY, Z, &I) + : BinaryOperator::CreateFDivFMF(XY, Z, &I); +} + +Instruction *InstCombinerImpl::visitFAdd(BinaryOperator &I) { + if (Value *V = simplifyFAddInst(I.getOperand(0), I.getOperand(1), + I.getFastMathFlags(), + SQ.getWithInstruction(&I))) + return replaceInstUsesWith(I, V); + + if (SimplifyAssociativeOrCommutative(I)) + return &I; + + if (Instruction *X = foldVectorBinop(I)) + return X; + + if (Instruction *Phi = foldBinopWithPhiOperands(I)) + return Phi; + + if (Instruction *FoldedFAdd = foldBinOpIntoSelectOrPhi(I)) + return FoldedFAdd; + + // (-X) + Y --> Y - X + Value *X, *Y; + if (match(&I, m_c_FAdd(m_FNeg(m_Value(X)), m_Value(Y)))) + return BinaryOperator::CreateFSubFMF(Y, X, &I); + + // Similar to above, but look through fmul/fdiv for the negated term. + // (-X * Y) + Z --> Z - (X * Y) [4 commuted variants] + Value *Z; + if (match(&I, m_c_FAdd(m_OneUse(m_c_FMul(m_FNeg(m_Value(X)), m_Value(Y))), + m_Value(Z)))) { + Value *XY = Builder.CreateFMulFMF(X, Y, &I); + return BinaryOperator::CreateFSubFMF(Z, XY, &I); + } + // (-X / Y) + Z --> Z - (X / Y) [2 commuted variants] + // (X / -Y) + Z --> Z - (X / Y) [2 commuted variants] + if (match(&I, m_c_FAdd(m_OneUse(m_FDiv(m_FNeg(m_Value(X)), m_Value(Y))), + m_Value(Z))) || + match(&I, m_c_FAdd(m_OneUse(m_FDiv(m_Value(X), m_FNeg(m_Value(Y)))), + m_Value(Z)))) { + Value *XY = Builder.CreateFDivFMF(X, Y, &I); + return BinaryOperator::CreateFSubFMF(Z, XY, &I); + } + + // Check for (fadd double (sitofp x), y), see if we can merge this into an + // integer add followed by a promotion. + Value *LHS = I.getOperand(0), *RHS = I.getOperand(1); + if (SIToFPInst *LHSConv = dyn_cast<SIToFPInst>(LHS)) { + Value *LHSIntVal = LHSConv->getOperand(0); + Type *FPType = LHSConv->getType(); + + // TODO: This check is overly conservative. In many cases known bits + // analysis can tell us that the result of the addition has less significant + // bits than the integer type can hold. 
+ auto IsValidPromotion = [](Type *FTy, Type *ITy) { + Type *FScalarTy = FTy->getScalarType(); + Type *IScalarTy = ITy->getScalarType(); + + // Do we have enough bits in the significand to represent the result of + // the integer addition? + unsigned MaxRepresentableBits = + APFloat::semanticsPrecision(FScalarTy->getFltSemantics()); + return IScalarTy->getIntegerBitWidth() <= MaxRepresentableBits; + }; + + // (fadd double (sitofp x), fpcst) --> (sitofp (add int x, intcst)) + // ... if the constant fits in the integer value. This is useful for things + // like (double)(x & 1234) + 4.0 -> (double)((X & 1234)+4) which no longer + // requires a constant pool load, and generally allows the add to be better + // instcombined. + if (ConstantFP *CFP = dyn_cast<ConstantFP>(RHS)) + if (IsValidPromotion(FPType, LHSIntVal->getType())) { + Constant *CI = + ConstantExpr::getFPToSI(CFP, LHSIntVal->getType()); + if (LHSConv->hasOneUse() && + ConstantExpr::getSIToFP(CI, I.getType()) == CFP && + willNotOverflowSignedAdd(LHSIntVal, CI, I)) { + // Insert the new integer add. + Value *NewAdd = Builder.CreateNSWAdd(LHSIntVal, CI, "addconv"); + return new SIToFPInst(NewAdd, I.getType()); + } + } + + // (fadd double (sitofp x), (sitofp y)) --> (sitofp (add int x, y)) + if (SIToFPInst *RHSConv = dyn_cast<SIToFPInst>(RHS)) { + Value *RHSIntVal = RHSConv->getOperand(0); + // It's enough to check LHS types only because we require int types to + // be the same for this transform. + if (IsValidPromotion(FPType, LHSIntVal->getType())) { + // Only do this if x/y have the same type, if at least one of them has a + // single use (so we don't increase the number of int->fp conversions), + // and if the integer add will not overflow. + if (LHSIntVal->getType() == RHSIntVal->getType() && + (LHSConv->hasOneUse() || RHSConv->hasOneUse()) && + willNotOverflowSignedAdd(LHSIntVal, RHSIntVal, I)) { + // Insert the new integer add. 
+ Value *NewAdd = Builder.CreateNSWAdd(LHSIntVal, RHSIntVal, "addconv"); + return new SIToFPInst(NewAdd, I.getType()); + } + } + } + } + + // Handle specials cases for FAdd with selects feeding the operation + if (Value *V = SimplifySelectsFeedingBinaryOp(I, LHS, RHS)) + return replaceInstUsesWith(I, V); + + if (I.hasAllowReassoc() && I.hasNoSignedZeros()) { + if (Instruction *F = factorizeFAddFSub(I, Builder)) + return F; + + // Try to fold fadd into start value of reduction intrinsic. + if (match(&I, m_c_FAdd(m_OneUse(m_Intrinsic<Intrinsic::vector_reduce_fadd>( + m_AnyZeroFP(), m_Value(X))), + m_Value(Y)))) { + // fadd (rdx 0.0, X), Y --> rdx Y, X + return replaceInstUsesWith( + I, Builder.CreateIntrinsic(Intrinsic::vector_reduce_fadd, + {X->getType()}, {Y, X}, &I)); + } + const APFloat *StartC, *C; + if (match(LHS, m_OneUse(m_Intrinsic<Intrinsic::vector_reduce_fadd>( + m_APFloat(StartC), m_Value(X)))) && + match(RHS, m_APFloat(C))) { + // fadd (rdx StartC, X), C --> rdx (C + StartC), X + Constant *NewStartC = ConstantFP::get(I.getType(), *C + *StartC); + return replaceInstUsesWith( + I, Builder.CreateIntrinsic(Intrinsic::vector_reduce_fadd, + {X->getType()}, {NewStartC, X}, &I)); + } + + // (X * MulC) + X --> X * (MulC + 1.0) + Constant *MulC; + if (match(&I, m_c_FAdd(m_FMul(m_Value(X), m_ImmConstant(MulC)), + m_Deferred(X)))) { + MulC = ConstantExpr::getFAdd(MulC, ConstantFP::get(I.getType(), 1.0)); + return BinaryOperator::CreateFMulFMF(X, MulC, &I); + } + + if (Value *V = FAddCombine(Builder).simplify(&I)) + return replaceInstUsesWith(I, V); + } + + return nullptr; +} + +/// Optimize pointer differences into the same array into a size. Consider: +/// &A[10] - &A[0]: we should compile this to "10". LHS/RHS are the pointer +/// operands to the ptrtoint instructions for the LHS/RHS of the subtract. 
+Value *InstCombinerImpl::OptimizePointerDifference(Value *LHS, Value *RHS, + Type *Ty, bool IsNUW) { + // If LHS is a gep based on RHS or RHS is a gep based on LHS, we can optimize + // this. + bool Swapped = false; + GEPOperator *GEP1 = nullptr, *GEP2 = nullptr; + if (!isa<GEPOperator>(LHS) && isa<GEPOperator>(RHS)) { + std::swap(LHS, RHS); + Swapped = true; + } + + // Require at least one GEP with a common base pointer on both sides. + if (auto *LHSGEP = dyn_cast<GEPOperator>(LHS)) { + // (gep X, ...) - X + if (LHSGEP->getOperand(0)->stripPointerCasts() == + RHS->stripPointerCasts()) { + GEP1 = LHSGEP; + } else if (auto *RHSGEP = dyn_cast<GEPOperator>(RHS)) { + // (gep X, ...) - (gep X, ...) + if (LHSGEP->getOperand(0)->stripPointerCasts() == + RHSGEP->getOperand(0)->stripPointerCasts()) { + GEP1 = LHSGEP; + GEP2 = RHSGEP; + } + } + } + + if (!GEP1) + return nullptr; + + if (GEP2) { + // (gep X, ...) - (gep X, ...) + // + // Avoid duplicating the arithmetic if there are more than one non-constant + // indices between the two GEPs and either GEP has a non-constant index and + // multiple users. If zero non-constant index, the result is a constant and + // there is no duplication. If one non-constant index, the result is an add + // or sub with a constant, which is no larger than the original code, and + // there's no duplicated arithmetic, even if either GEP has multiple + // users. If more than one non-constant indices combined, as long as the GEP + // with at least one non-constant index doesn't have multiple users, there + // is no duplication. + unsigned NumNonConstantIndices1 = GEP1->countNonConstantIndices(); + unsigned NumNonConstantIndices2 = GEP2->countNonConstantIndices(); + if (NumNonConstantIndices1 + NumNonConstantIndices2 > 1 && + ((NumNonConstantIndices1 > 0 && !GEP1->hasOneUse()) || + (NumNonConstantIndices2 > 0 && !GEP2->hasOneUse()))) { + return nullptr; + } + } + + // Emit the offset of the GEP and an intptr_t. 
+ Value *Result = EmitGEPOffset(GEP1); + + // If this is a single inbounds GEP and the original sub was nuw, + // then the final multiplication is also nuw. + if (auto *I = dyn_cast<Instruction>(Result)) + if (IsNUW && !GEP2 && !Swapped && GEP1->isInBounds() && + I->getOpcode() == Instruction::Mul) + I->setHasNoUnsignedWrap(); + + // If we have a 2nd GEP of the same base pointer, subtract the offsets. + // If both GEPs are inbounds, then the subtract does not have signed overflow. + if (GEP2) { + Value *Offset = EmitGEPOffset(GEP2); + Result = Builder.CreateSub(Result, Offset, "gepdiff", /* NUW */ false, + GEP1->isInBounds() && GEP2->isInBounds()); + } + + // If we have p - gep(p, ...) then we have to negate the result. + if (Swapped) + Result = Builder.CreateNeg(Result, "diff.neg"); + + return Builder.CreateIntCast(Result, Ty, true); +} + +Instruction *InstCombinerImpl::visitSub(BinaryOperator &I) { + if (Value *V = simplifySubInst(I.getOperand(0), I.getOperand(1), + I.hasNoSignedWrap(), I.hasNoUnsignedWrap(), + SQ.getWithInstruction(&I))) + return replaceInstUsesWith(I, V); + + if (Instruction *X = foldVectorBinop(I)) + return X; + + if (Instruction *Phi = foldBinopWithPhiOperands(I)) + return Phi; + + Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); + + // If this is a 'B = x-(-A)', change to B = x+A. + // We deal with this without involving Negator to preserve NSW flag. + if (Value *V = dyn_castNegVal(Op1)) { + BinaryOperator *Res = BinaryOperator::CreateAdd(Op0, V); + + if (const auto *BO = dyn_cast<BinaryOperator>(Op1)) { + assert(BO->getOpcode() == Instruction::Sub && + "Expected a subtraction operator!"); + if (BO->hasNoSignedWrap() && I.hasNoSignedWrap()) + Res->setHasNoSignedWrap(true); + } else { + if (cast<Constant>(Op1)->isNotMinSignedValue() && I.hasNoSignedWrap()) + Res->setHasNoSignedWrap(true); + } + + return Res; + } + + // Try this before Negator to preserve NSW flag. 
+ if (Instruction *R = factorizeMathWithShlOps(I, Builder)) + return R; + + Constant *C; + if (match(Op0, m_ImmConstant(C))) { + Value *X; + Constant *C2; + + // C-(X+C2) --> (C-C2)-X + if (match(Op1, m_Add(m_Value(X), m_ImmConstant(C2)))) + return BinaryOperator::CreateSub(ConstantExpr::getSub(C, C2), X); + } + + auto TryToNarrowDeduceFlags = [this, &I, &Op0, &Op1]() -> Instruction * { + if (Instruction *Ext = narrowMathIfNoOverflow(I)) + return Ext; + + bool Changed = false; + if (!I.hasNoSignedWrap() && willNotOverflowSignedSub(Op0, Op1, I)) { + Changed = true; + I.setHasNoSignedWrap(true); + } + if (!I.hasNoUnsignedWrap() && willNotOverflowUnsignedSub(Op0, Op1, I)) { + Changed = true; + I.setHasNoUnsignedWrap(true); + } + + return Changed ? &I : nullptr; + }; + + // First, let's try to interpret `sub a, b` as `add a, (sub 0, b)`, + // and let's try to sink `(sub 0, b)` into `b` itself. But only if this isn't + // a pure negation used by a select that looks like abs/nabs. + bool IsNegation = match(Op0, m_ZeroInt()); + if (!IsNegation || none_of(I.users(), [&I, Op1](const User *U) { + const Instruction *UI = dyn_cast<Instruction>(U); + if (!UI) + return false; + return match(UI, + m_Select(m_Value(), m_Specific(Op1), m_Specific(&I))) || + match(UI, m_Select(m_Value(), m_Specific(&I), m_Specific(Op1))); + })) { + if (Value *NegOp1 = Negator::Negate(IsNegation, Op1, *this)) + return BinaryOperator::CreateAdd(NegOp1, Op0); + } + if (IsNegation) + return TryToNarrowDeduceFlags(); // Should have been handled in Negator! + + // (A*B)-(A*C) -> A*(B-C) etc + if (Value *V = SimplifyUsingDistributiveLaws(I)) + return replaceInstUsesWith(I, V); + + if (I.getType()->isIntOrIntVectorTy(1)) + return BinaryOperator::CreateXor(Op0, Op1); + + // Replace (-1 - A) with (~A). 
+ if (match(Op0, m_AllOnes())) + return BinaryOperator::CreateNot(Op1); + + // (X + -1) - Y --> ~Y + X + Value *X, *Y; + if (match(Op0, m_OneUse(m_Add(m_Value(X), m_AllOnes())))) + return BinaryOperator::CreateAdd(Builder.CreateNot(Op1), X); + + // Reassociate sub/add sequences to create more add instructions and + // reduce dependency chains: + // ((X - Y) + Z) - Op1 --> (X + Z) - (Y + Op1) + Value *Z; + if (match(Op0, m_OneUse(m_c_Add(m_OneUse(m_Sub(m_Value(X), m_Value(Y))), + m_Value(Z))))) { + Value *XZ = Builder.CreateAdd(X, Z); + Value *YW = Builder.CreateAdd(Y, Op1); + return BinaryOperator::CreateSub(XZ, YW); + } + + // ((X - Y) - Op1) --> X - (Y + Op1) + if (match(Op0, m_OneUse(m_Sub(m_Value(X), m_Value(Y))))) { + Value *Add = Builder.CreateAdd(Y, Op1); + return BinaryOperator::CreateSub(X, Add); + } + + // (~X) - (~Y) --> Y - X + // This is placed after the other reassociations and explicitly excludes a + // sub-of-sub pattern to avoid infinite looping. + if (isFreeToInvert(Op0, Op0->hasOneUse()) && + isFreeToInvert(Op1, Op1->hasOneUse()) && + !match(Op0, m_Sub(m_ImmConstant(), m_Value()))) { + Value *NotOp0 = Builder.CreateNot(Op0); + Value *NotOp1 = Builder.CreateNot(Op1); + return BinaryOperator::CreateSub(NotOp1, NotOp0); + } + + auto m_AddRdx = [](Value *&Vec) { + return m_OneUse(m_Intrinsic<Intrinsic::vector_reduce_add>(m_Value(Vec))); + }; + Value *V0, *V1; + if (match(Op0, m_AddRdx(V0)) && match(Op1, m_AddRdx(V1)) && + V0->getType() == V1->getType()) { + // Difference of sums is sum of differences: + // add_rdx(V0) - add_rdx(V1) --> add_rdx(V0 - V1) + Value *Sub = Builder.CreateSub(V0, V1); + Value *Rdx = Builder.CreateIntrinsic(Intrinsic::vector_reduce_add, + {Sub->getType()}, {Sub}); + return replaceInstUsesWith(I, Rdx); + } + + if (Constant *C = dyn_cast<Constant>(Op0)) { + Value *X; + if (match(Op1, m_ZExt(m_Value(X))) && X->getType()->isIntOrIntVectorTy(1)) + // C - (zext bool) --> bool ? 
C - 1 : C + return SelectInst::Create(X, InstCombiner::SubOne(C), C); + if (match(Op1, m_SExt(m_Value(X))) && X->getType()->isIntOrIntVectorTy(1)) + // C - (sext bool) --> bool ? C + 1 : C + return SelectInst::Create(X, InstCombiner::AddOne(C), C); + + // C - ~X == X + (1+C) + if (match(Op1, m_Not(m_Value(X)))) + return BinaryOperator::CreateAdd(X, InstCombiner::AddOne(C)); + + // Try to fold constant sub into select arguments. + if (SelectInst *SI = dyn_cast<SelectInst>(Op1)) + if (Instruction *R = FoldOpIntoSelect(I, SI)) + return R; + + // Try to fold constant sub into PHI values. + if (PHINode *PN = dyn_cast<PHINode>(Op1)) + if (Instruction *R = foldOpIntoPhi(I, PN)) + return R; + + Constant *C2; + + // C-(C2-X) --> X+(C-C2) + if (match(Op1, m_Sub(m_ImmConstant(C2), m_Value(X)))) + return BinaryOperator::CreateAdd(X, ConstantExpr::getSub(C, C2)); + } + + const APInt *Op0C; + if (match(Op0, m_APInt(Op0C)) && Op0C->isMask()) { + // Turn this into a xor if LHS is 2^n-1 and the remaining bits are known + // zero. 
+ KnownBits RHSKnown = computeKnownBits(Op1, 0, &I); + if ((*Op0C | RHSKnown.Zero).isAllOnes()) + return BinaryOperator::CreateXor(Op1, Op0); + } + + { + Value *Y; + // X-(X+Y) == -Y X-(Y+X) == -Y + if (match(Op1, m_c_Add(m_Specific(Op0), m_Value(Y)))) + return BinaryOperator::CreateNeg(Y); + + // (X-Y)-X == -Y + if (match(Op0, m_Sub(m_Specific(Op1), m_Value(Y)))) + return BinaryOperator::CreateNeg(Y); + } + + // (sub (or A, B) (and A, B)) --> (xor A, B) + { + Value *A, *B; + if (match(Op1, m_And(m_Value(A), m_Value(B))) && + match(Op0, m_c_Or(m_Specific(A), m_Specific(B)))) + return BinaryOperator::CreateXor(A, B); + } + + // (sub (add A, B) (or A, B)) --> (and A, B) + { + Value *A, *B; + if (match(Op0, m_Add(m_Value(A), m_Value(B))) && + match(Op1, m_c_Or(m_Specific(A), m_Specific(B)))) + return BinaryOperator::CreateAnd(A, B); + } + + // (sub (add A, B) (and A, B)) --> (or A, B) + { + Value *A, *B; + if (match(Op0, m_Add(m_Value(A), m_Value(B))) && + match(Op1, m_c_And(m_Specific(A), m_Specific(B)))) + return BinaryOperator::CreateOr(A, B); + } + + // (sub (and A, B) (or A, B)) --> neg (xor A, B) + { + Value *A, *B; + if (match(Op0, m_And(m_Value(A), m_Value(B))) && + match(Op1, m_c_Or(m_Specific(A), m_Specific(B))) && + (Op0->hasOneUse() || Op1->hasOneUse())) + return BinaryOperator::CreateNeg(Builder.CreateXor(A, B)); + } + + // (sub (or A, B), (xor A, B)) --> (and A, B) + { + Value *A, *B; + if (match(Op1, m_Xor(m_Value(A), m_Value(B))) && + match(Op0, m_c_Or(m_Specific(A), m_Specific(B)))) + return BinaryOperator::CreateAnd(A, B); + } + + // (sub (xor A, B) (or A, B)) --> neg (and A, B) + { + Value *A, *B; + if (match(Op0, m_Xor(m_Value(A), m_Value(B))) && + match(Op1, m_c_Or(m_Specific(A), m_Specific(B))) && + (Op0->hasOneUse() || Op1->hasOneUse())) + return BinaryOperator::CreateNeg(Builder.CreateAnd(A, B)); + } + + { + Value *Y; + // ((X | Y) - X) --> (~X & Y) + if (match(Op0, m_OneUse(m_c_Or(m_Value(Y), m_Specific(Op1))))) + return 
BinaryOperator::CreateAnd( + Y, Builder.CreateNot(Op1, Op1->getName() + ".not")); + } + + { + // (sub (and Op1, (neg X)), Op1) --> neg (and Op1, (add X, -1)) + Value *X; + if (match(Op0, m_OneUse(m_c_And(m_Specific(Op1), + m_OneUse(m_Neg(m_Value(X))))))) { + return BinaryOperator::CreateNeg(Builder.CreateAnd( + Op1, Builder.CreateAdd(X, Constant::getAllOnesValue(I.getType())))); + } + } + + { + // (sub (and Op1, C), Op1) --> neg (and Op1, ~C) + Constant *C; + if (match(Op0, m_OneUse(m_And(m_Specific(Op1), m_Constant(C))))) { + return BinaryOperator::CreateNeg( + Builder.CreateAnd(Op1, Builder.CreateNot(C))); + } + } + + if (auto *II = dyn_cast<MinMaxIntrinsic>(Op1)) { + { + // sub(add(X,Y), s/umin(X,Y)) --> s/umax(X,Y) + // sub(add(X,Y), s/umax(X,Y)) --> s/umin(X,Y) + Value *X = II->getLHS(); + Value *Y = II->getRHS(); + if (match(Op0, m_c_Add(m_Specific(X), m_Specific(Y))) && + (Op0->hasOneUse() || Op1->hasOneUse())) { + Intrinsic::ID InvID = getInverseMinMaxIntrinsic(II->getIntrinsicID()); + Value *InvMaxMin = Builder.CreateBinaryIntrinsic(InvID, X, Y); + return replaceInstUsesWith(I, InvMaxMin); + } + } + + { + // sub(add(X,Y),umin(Y,Z)) --> add(X,usub.sat(Y,Z)) + // sub(add(X,Z),umin(Y,Z)) --> add(X,usub.sat(Z,Y)) + Value *X, *Y, *Z; + if (match(Op1, m_OneUse(m_UMin(m_Value(Y), m_Value(Z))))) { + if (match(Op0, m_OneUse(m_c_Add(m_Specific(Y), m_Value(X))))) + return BinaryOperator::CreateAdd( + X, Builder.CreateIntrinsic(Intrinsic::usub_sat, I.getType(), + {Y, Z})); + if (match(Op0, m_OneUse(m_c_Add(m_Specific(Z), m_Value(X))))) + return BinaryOperator::CreateAdd( + X, Builder.CreateIntrinsic(Intrinsic::usub_sat, I.getType(), + {Z, Y})); + } + } + } + + { + // If we have a subtraction between some value and a select between + // said value and something else, sink subtraction into select hands, i.e.: + // sub (select %Cond, %TrueVal, %FalseVal), %Op1 + // -> + // select %Cond, (sub %TrueVal, %Op1), (sub %FalseVal, %Op1) + // or + // sub %Op0, (select %Cond, 
%TrueVal, %FalseVal) + // -> + // select %Cond, (sub %Op0, %TrueVal), (sub %Op0, %FalseVal) + // This will result in select between new subtraction and 0. + auto SinkSubIntoSelect = + [Ty = I.getType()](Value *Select, Value *OtherHandOfSub, + auto SubBuilder) -> Instruction * { + Value *Cond, *TrueVal, *FalseVal; + if (!match(Select, m_OneUse(m_Select(m_Value(Cond), m_Value(TrueVal), + m_Value(FalseVal))))) + return nullptr; + if (OtherHandOfSub != TrueVal && OtherHandOfSub != FalseVal) + return nullptr; + // While it is really tempting to just create two subtractions and let + // InstCombine fold one of those to 0, it isn't possible to do so + // because of worklist visitation order. So ugly it is. + bool OtherHandOfSubIsTrueVal = OtherHandOfSub == TrueVal; + Value *NewSub = SubBuilder(OtherHandOfSubIsTrueVal ? FalseVal : TrueVal); + Constant *Zero = Constant::getNullValue(Ty); + SelectInst *NewSel = + SelectInst::Create(Cond, OtherHandOfSubIsTrueVal ? Zero : NewSub, + OtherHandOfSubIsTrueVal ? NewSub : Zero); + // Preserve prof metadata if any. 
+ NewSel->copyMetadata(cast<Instruction>(*Select)); + return NewSel; + }; + if (Instruction *NewSel = SinkSubIntoSelect( + /*Select=*/Op0, /*OtherHandOfSub=*/Op1, + [Builder = &Builder, Op1](Value *OtherHandOfSelect) { + return Builder->CreateSub(OtherHandOfSelect, + /*OtherHandOfSub=*/Op1); + })) + return NewSel; + if (Instruction *NewSel = SinkSubIntoSelect( + /*Select=*/Op1, /*OtherHandOfSub=*/Op0, + [Builder = &Builder, Op0](Value *OtherHandOfSelect) { + return Builder->CreateSub(/*OtherHandOfSub=*/Op0, + OtherHandOfSelect); + })) + return NewSel; + } + + // (X - (X & Y)) --> (X & ~Y) + if (match(Op1, m_c_And(m_Specific(Op0), m_Value(Y))) && + (Op1->hasOneUse() || isa<Constant>(Y))) + return BinaryOperator::CreateAnd( + Op0, Builder.CreateNot(Y, Y->getName() + ".not")); + + // ~X - Min/Max(~X, Y) -> ~Min/Max(X, ~Y) - X + // ~X - Min/Max(Y, ~X) -> ~Min/Max(X, ~Y) - X + // Min/Max(~X, Y) - ~X -> X - ~Min/Max(X, ~Y) + // Min/Max(Y, ~X) - ~X -> X - ~Min/Max(X, ~Y) + // As long as Y is freely invertible, this will be neutral or a win. + // Note: We don't generate the inverse max/min, just create the 'not' of + // it and let other folds do the rest. + if (match(Op0, m_Not(m_Value(X))) && + match(Op1, m_c_MaxOrMin(m_Specific(Op0), m_Value(Y))) && + !Op0->hasNUsesOrMore(3) && isFreeToInvert(Y, Y->hasOneUse())) { + Value *Not = Builder.CreateNot(Op1); + return BinaryOperator::CreateSub(Not, X); + } + if (match(Op1, m_Not(m_Value(X))) && + match(Op0, m_c_MaxOrMin(m_Specific(Op1), m_Value(Y))) && + !Op1->hasNUsesOrMore(3) && isFreeToInvert(Y, Y->hasOneUse())) { + Value *Not = Builder.CreateNot(Op0); + return BinaryOperator::CreateSub(X, Not); + } + + // Optimize pointer differences into the same array into a size. Consider: + // &A[10] - &A[0]: we should compile this to "10". 
+ Value *LHSOp, *RHSOp; + if (match(Op0, m_PtrToInt(m_Value(LHSOp))) && + match(Op1, m_PtrToInt(m_Value(RHSOp)))) + if (Value *Res = OptimizePointerDifference(LHSOp, RHSOp, I.getType(), + I.hasNoUnsignedWrap())) + return replaceInstUsesWith(I, Res); + + // trunc(p)-trunc(q) -> trunc(p-q) + if (match(Op0, m_Trunc(m_PtrToInt(m_Value(LHSOp)))) && + match(Op1, m_Trunc(m_PtrToInt(m_Value(RHSOp))))) + if (Value *Res = OptimizePointerDifference(LHSOp, RHSOp, I.getType(), + /* IsNUW */ false)) + return replaceInstUsesWith(I, Res); + + // Canonicalize a shifty way to code absolute value to the common pattern. + // There are 2 potential commuted variants. + // We're relying on the fact that we only do this transform when the shift has + // exactly 2 uses and the xor has exactly 1 use (otherwise, we might increase + // instructions). + Value *A; + const APInt *ShAmt; + Type *Ty = I.getType(); + if (match(Op1, m_AShr(m_Value(A), m_APInt(ShAmt))) && + Op1->hasNUses(2) && *ShAmt == Ty->getScalarSizeInBits() - 1 && + match(Op0, m_OneUse(m_c_Xor(m_Specific(A), m_Specific(Op1))))) { + // B = ashr i32 A, 31 ; smear the sign bit + // sub (xor A, B), B ; flip bits if negative and subtract -1 (add 1) + // --> (A < 0) ? -A : A + Value *IsNeg = Builder.CreateIsNeg(A); + // Copy the nuw/nsw flags from the sub to the negate. 
+ Value *NegA = Builder.CreateNeg(A, "", I.hasNoUnsignedWrap(), + I.hasNoSignedWrap()); + return SelectInst::Create(IsNeg, NegA, A); + } + + // If we are subtracting a low-bit masked subset of some value from an add + // of that same value with no low bits changed, that is clearing some low bits + // of the sum: + // sub (X + AddC), (X & AndC) --> and (X + AddC), ~AndC + const APInt *AddC, *AndC; + if (match(Op0, m_Add(m_Value(X), m_APInt(AddC))) && + match(Op1, m_And(m_Specific(X), m_APInt(AndC)))) { + unsigned BitWidth = Ty->getScalarSizeInBits(); + unsigned Cttz = AddC->countTrailingZeros(); + APInt HighMask(APInt::getHighBitsSet(BitWidth, BitWidth - Cttz)); + if ((HighMask & *AndC).isZero()) + return BinaryOperator::CreateAnd(Op0, ConstantInt::get(Ty, ~(*AndC))); + } + + if (Instruction *V = + canonicalizeCondSignextOfHighBitExtractToSignextHighBitExtract(I)) + return V; + + // X - usub.sat(X, Y) => umin(X, Y) + if (match(Op1, m_OneUse(m_Intrinsic<Intrinsic::usub_sat>(m_Specific(Op0), + m_Value(Y))))) + return replaceInstUsesWith( + I, Builder.CreateIntrinsic(Intrinsic::umin, {I.getType()}, {Op0, Y})); + + // umax(X, Op1) - Op1 --> usub.sat(X, Op1) + // TODO: The one-use restriction is not strictly necessary, but it may + // require improving other pattern matching and/or codegen. 
+ if (match(Op0, m_OneUse(m_c_UMax(m_Value(X), m_Specific(Op1))))) + return replaceInstUsesWith( + I, Builder.CreateIntrinsic(Intrinsic::usub_sat, {Ty}, {X, Op1})); + + // Op0 - umin(X, Op0) --> usub.sat(Op0, X) + if (match(Op1, m_OneUse(m_c_UMin(m_Value(X), m_Specific(Op0))))) + return replaceInstUsesWith( + I, Builder.CreateIntrinsic(Intrinsic::usub_sat, {Ty}, {Op0, X})); + + // Op0 - umax(X, Op0) --> 0 - usub.sat(X, Op0) + if (match(Op1, m_OneUse(m_c_UMax(m_Value(X), m_Specific(Op0))))) { + Value *USub = Builder.CreateIntrinsic(Intrinsic::usub_sat, {Ty}, {X, Op0}); + return BinaryOperator::CreateNeg(USub); + } + + // umin(X, Op1) - Op1 --> 0 - usub.sat(Op1, X) + if (match(Op0, m_OneUse(m_c_UMin(m_Value(X), m_Specific(Op1))))) { + Value *USub = Builder.CreateIntrinsic(Intrinsic::usub_sat, {Ty}, {Op1, X}); + return BinaryOperator::CreateNeg(USub); + } + + // C - ctpop(X) => ctpop(~X) if C is bitwidth + if (match(Op0, m_SpecificInt(Ty->getScalarSizeInBits())) && + match(Op1, m_OneUse(m_Intrinsic<Intrinsic::ctpop>(m_Value(X))))) + return replaceInstUsesWith( + I, Builder.CreateIntrinsic(Intrinsic::ctpop, {I.getType()}, + {Builder.CreateNot(X)})); + + return TryToNarrowDeduceFlags(); +} + +/// This eliminates floating-point negation in either 'fneg(X)' or +/// 'fsub(-0.0, X)' form by combining into a constant operand. +static Instruction *foldFNegIntoConstant(Instruction &I) { + // This is limited with one-use because fneg is assumed better for + // reassociation and cheaper in codegen than fmul/fdiv. + // TODO: Should the m_OneUse restriction be removed? + Instruction *FNegOp; + if (!match(&I, m_FNeg(m_OneUse(m_Instruction(FNegOp))))) + return nullptr; + + Value *X; + Constant *C; + + // Fold negation into constant operand. 
+ // -(X * C) --> X * (-C) + if (match(FNegOp, m_FMul(m_Value(X), m_Constant(C)))) + return BinaryOperator::CreateFMulFMF(X, ConstantExpr::getFNeg(C), &I); + // -(X / C) --> X / (-C) + if (match(FNegOp, m_FDiv(m_Value(X), m_Constant(C)))) + return BinaryOperator::CreateFDivFMF(X, ConstantExpr::getFNeg(C), &I); + // -(C / X) --> (-C) / X + if (match(FNegOp, m_FDiv(m_Constant(C), m_Value(X)))) { + Instruction *FDiv = + BinaryOperator::CreateFDivFMF(ConstantExpr::getFNeg(C), X, &I); + + // Intersect 'nsz' and 'ninf' because those special value exceptions may not + // apply to the fdiv. Everything else propagates from the fneg. + // TODO: We could propagate nsz/ninf from fdiv alone? + FastMathFlags FMF = I.getFastMathFlags(); + FastMathFlags OpFMF = FNegOp->getFastMathFlags(); + FDiv->setHasNoSignedZeros(FMF.noSignedZeros() && OpFMF.noSignedZeros()); + FDiv->setHasNoInfs(FMF.noInfs() && OpFMF.noInfs()); + return FDiv; + } + // With NSZ [ counter-example with -0.0: -(-0.0 + 0.0) != 0.0 + -0.0 ]: + // -(X + C) --> -X + -C --> -C - X + if (I.hasNoSignedZeros() && match(FNegOp, m_FAdd(m_Value(X), m_Constant(C)))) + return BinaryOperator::CreateFSubFMF(ConstantExpr::getFNeg(C), X, &I); + + return nullptr; +} + +static Instruction *hoistFNegAboveFMulFDiv(Instruction &I, + InstCombiner::BuilderTy &Builder) { + Value *FNeg; + if (!match(&I, m_FNeg(m_Value(FNeg)))) + return nullptr; + + Value *X, *Y; + if (match(FNeg, m_OneUse(m_FMul(m_Value(X), m_Value(Y))))) + return BinaryOperator::CreateFMulFMF(Builder.CreateFNegFMF(X, &I), Y, &I); + + if (match(FNeg, m_OneUse(m_FDiv(m_Value(X), m_Value(Y))))) + return BinaryOperator::CreateFDivFMF(Builder.CreateFNegFMF(X, &I), Y, &I); + + return nullptr; +} + +Instruction *InstCombinerImpl::visitFNeg(UnaryOperator &I) { + Value *Op = I.getOperand(0); + + if (Value *V = simplifyFNegInst(Op, I.getFastMathFlags(), + getSimplifyQuery().getWithInstruction(&I))) + return replaceInstUsesWith(I, V); + + if (Instruction *X = 
foldFNegIntoConstant(I)) + return X; + + Value *X, *Y; + + // If we can ignore the sign of zeros: -(X - Y) --> (Y - X) + if (I.hasNoSignedZeros() && + match(Op, m_OneUse(m_FSub(m_Value(X), m_Value(Y))))) + return BinaryOperator::CreateFSubFMF(Y, X, &I); + + if (Instruction *R = hoistFNegAboveFMulFDiv(I, Builder)) + return R; + + // Try to eliminate fneg if at least 1 arm of the select is negated. + Value *Cond; + if (match(Op, m_OneUse(m_Select(m_Value(Cond), m_Value(X), m_Value(Y))))) { + // Unlike most transforms, this one is not safe to propagate nsz unless + // it is present on the original select. (We are conservatively intersecting + // the nsz flags from the select and root fneg instruction.) + auto propagateSelectFMF = [&](SelectInst *S, bool CommonOperand) { + S->copyFastMathFlags(&I); + if (auto *OldSel = dyn_cast<SelectInst>(Op)) + if (!OldSel->hasNoSignedZeros() && !CommonOperand && + !isGuaranteedNotToBeUndefOrPoison(OldSel->getCondition())) + S->setHasNoSignedZeros(false); + }; + // -(Cond ? -P : Y) --> Cond ? P : -Y + Value *P; + if (match(X, m_FNeg(m_Value(P)))) { + Value *NegY = Builder.CreateFNegFMF(Y, &I, Y->getName() + ".neg"); + SelectInst *NewSel = SelectInst::Create(Cond, P, NegY); + propagateSelectFMF(NewSel, P == Y); + return NewSel; + } + // -(Cond ? X : -P) --> Cond ? -X : P + if (match(Y, m_FNeg(m_Value(P)))) { + Value *NegX = Builder.CreateFNegFMF(X, &I, X->getName() + ".neg"); + SelectInst *NewSel = SelectInst::Create(Cond, NegX, P); + propagateSelectFMF(NewSel, P == X); + return NewSel; + } + } + + return nullptr; +} + +Instruction *InstCombinerImpl::visitFSub(BinaryOperator &I) { + if (Value *V = simplifyFSubInst(I.getOperand(0), I.getOperand(1), + I.getFastMathFlags(), + getSimplifyQuery().getWithInstruction(&I))) + return replaceInstUsesWith(I, V); + + if (Instruction *X = foldVectorBinop(I)) + return X; + + if (Instruction *Phi = foldBinopWithPhiOperands(I)) + return Phi; + + // Subtraction from -0.0 is the canonical form of fneg. 
  // fsub -0.0, X ==> fneg X
  // fsub nsz 0.0, X ==> fneg nsz X
  //
  // FIXME This matcher does not respect FTZ or DAZ yet:
  // fsub -0.0, Denorm ==> +-0
  // fneg Denorm ==> -Denorm
  Value *Op;
  if (match(&I, m_FNeg(m_Value(Op))))
    return UnaryOperator::CreateFNegFMF(Op, &I);

  // NOTE(review): both helpers below start by matching I against m_FNeg, and
  // that form has already returned just above, so they look like no-ops on
  // this path - confirm before relying on them here.
  if (Instruction *X = foldFNegIntoConstant(I))
    return X;

  if (Instruction *R = hoistFNegAboveFMulFDiv(I, Builder))
    return R;

  Value *X, *Y;
  Constant *C;

  Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
  // If Op0 is not -0.0 or we can ignore -0.0: Z - (X - Y) --> Z + (Y - X)
  // Canonicalize to fadd to make analysis easier.
  // This can also help codegen because fadd is commutative.
  // Note that if this fsub was really an fneg, the fadd with -0.0 will get
  // killed later. We still limit that particular transform with 'hasOneUse'
  // because an fneg is assumed better/cheaper than a generic fsub.
  if (I.hasNoSignedZeros() || CannotBeNegativeZero(Op0, SQ.TLI)) {
    if (match(Op1, m_OneUse(m_FSub(m_Value(X), m_Value(Y))))) {
      Value *NewSub = Builder.CreateFSubFMF(Y, X, &I);
      return BinaryOperator::CreateFAddFMF(Op0, NewSub, &I);
    }
  }

  // (-X) - Op1 --> -(X + Op1)
  if (I.hasNoSignedZeros() && !isa<ConstantExpr>(Op0) &&
      match(Op0, m_OneUse(m_FNeg(m_Value(X))))) {
    Value *FAdd = Builder.CreateFAddFMF(X, Op1, &I);
    return UnaryOperator::CreateFNegFMF(FAdd, &I);
  }

  // With a constant LHS and a select RHS, try to fold the subtraction into
  // the select.
  if (isa<Constant>(Op0))
    if (SelectInst *SI = dyn_cast<SelectInst>(Op1))
      if (Instruction *NV = FoldOpIntoSelect(I, SI))
        return NV;

  // X - C --> X + (-C)
  // But don't transform constant expressions because there's an inverse fold
  // for X + (-Y) --> X - Y.
  if (match(Op1, m_ImmConstant(C)))
    return BinaryOperator::CreateFAddFMF(Op0, ConstantExpr::getFNeg(C), &I);

  // X - (-Y) --> X + Y
  if (match(Op1, m_FNeg(m_Value(Y))))
    return BinaryOperator::CreateFAddFMF(Op0, Y, &I);

  // Similar to above, but look through a cast of the negated value:
  // X - (fptrunc(-Y)) --> X + fptrunc(Y)
  Type *Ty = I.getType();
  if (match(Op1, m_OneUse(m_FPTrunc(m_FNeg(m_Value(Y))))))
    return BinaryOperator::CreateFAddFMF(Op0, Builder.CreateFPTrunc(Y, Ty), &I);

  // X - (fpext(-Y)) --> X + fpext(Y)
  if (match(Op1, m_OneUse(m_FPExt(m_FNeg(m_Value(Y))))))
    return BinaryOperator::CreateFAddFMF(Op0, Builder.CreateFPExt(Y, Ty), &I);

  // Similar to above, but look through fmul/fdiv of the negated value:
  // Op0 - (-X * Y) --> Op0 + (X * Y)
  // Op0 - (Y * -X) --> Op0 + (X * Y)
  if (match(Op1, m_OneUse(m_c_FMul(m_FNeg(m_Value(X)), m_Value(Y))))) {
    Value *FMul = Builder.CreateFMulFMF(X, Y, &I);
    return BinaryOperator::CreateFAddFMF(Op0, FMul, &I);
  }
  // Op0 - (-X / Y) --> Op0 + (X / Y)
  // Op0 - (X / -Y) --> Op0 + (X / Y)
  if (match(Op1, m_OneUse(m_FDiv(m_FNeg(m_Value(X)), m_Value(Y)))) ||
      match(Op1, m_OneUse(m_FDiv(m_Value(X), m_FNeg(m_Value(Y)))))) {
    Value *FDiv = Builder.CreateFDivFMF(X, Y, &I);
    return BinaryOperator::CreateFAddFMF(Op0, FDiv, &I);
  }

  // Handle special cases for FSub with selects feeding the operation
  if (Value *V = SimplifySelectsFeedingBinaryOp(I, Op0, Op1))
    return replaceInstUsesWith(I, V);

  // Everything in this group reassociates operands, so it is gated on both
  // the 'reassoc' and 'nsz' flags of the fsub.
  if (I.hasAllowReassoc() && I.hasNoSignedZeros()) {
    // (Y - X) - Y --> -X
    if (match(Op0, m_FSub(m_Specific(Op1), m_Value(X))))
      return UnaryOperator::CreateFNegFMF(X, &I);

    // Y - (X + Y) --> -X
    // Y - (Y + X) --> -X
    if (match(Op1, m_c_FAdd(m_Specific(Op0), m_Value(X))))
      return UnaryOperator::CreateFNegFMF(X, &I);

    // (X * C) - X --> X * (C - 1.0)
    if (match(Op0, m_FMul(m_Specific(Op1), m_Constant(C)))) {
      Constant *CSubOne = ConstantExpr::getFSub(C, ConstantFP::get(Ty, 1.0));
      return BinaryOperator::CreateFMulFMF(Op1, CSubOne, &I);
    }
    // X - (X * C) --> X * (1.0 - C)
    if (match(Op1, m_FMul(m_Specific(Op0), m_Constant(C)))) {
      Constant *OneSubC = ConstantExpr::getFSub(ConstantFP::get(Ty, 1.0), C);
      return BinaryOperator::CreateFMulFMF(Op0, OneSubC, &I);
    }

    // Reassociate fsub/fadd sequences to create more fadd instructions and
    // reduce dependency chains:
    // ((X - Y) + Z) - Op1 --> (X + Z) - (Y + Op1)
    Value *Z;
    if (match(Op0, m_OneUse(m_c_FAdd(m_OneUse(m_FSub(m_Value(X), m_Value(Y))),
                                     m_Value(Z))))) {
      Value *XZ = Builder.CreateFAddFMF(X, Z, &I);
      // YW is Y + Op1 (Op1 plays the role of "W" in the pattern above).
      Value *YW = Builder.CreateFAddFMF(Y, Op1, &I);
      return BinaryOperator::CreateFSubFMF(XZ, YW, &I);
    }

    // Matcher for a single-use llvm.vector.reduce.fadd call; binds its first
    // argument to Sum and its second argument to Vec.
    auto m_FaddRdx = [](Value *&Sum, Value *&Vec) {
      return m_OneUse(m_Intrinsic<Intrinsic::vector_reduce_fadd>(m_Value(Sum),
                                                                 m_Value(Vec)));
    };
    Value *A0, *A1, *V0, *V1;
    if (match(Op0, m_FaddRdx(A0, V0)) && match(Op1, m_FaddRdx(A1, V1)) &&
        V0->getType() == V1->getType()) {
      // Difference of sums is sum of differences:
      // add_rdx(A0, V0) - add_rdx(A1, V1) --> add_rdx(A0, V0 - V1) - A1
      Value *Sub = Builder.CreateFSubFMF(V0, V1, &I);
      Value *Rdx = Builder.CreateIntrinsic(Intrinsic::vector_reduce_fadd,
                                           {Sub->getType()}, {A0, Sub}, &I);
      return BinaryOperator::CreateFSubFMF(Rdx, A1, &I);
    }

    if (Instruction *F = factorizeFAddFSub(I, Builder))
      return F;

    // TODO: This performs reassociative folds for FP ops. Some fraction of the
    // functionality has been subsumed by simple pattern matching here and in
    // InstSimplify. We should let a dedicated reassociation pass handle more
    // complex pattern matching and remove this from InstCombine.
    if (Value *V = FAddCombine(Builder).simplify(&I))
      return replaceInstUsesWith(I, V);

    // (X - Y) - Op1 --> X - (Y + Op1)
    if (match(Op0, m_OneUse(m_FSub(m_Value(X), m_Value(Y))))) {
      Value *FAdd = Builder.CreateFAddFMF(Y, Op1, &I);
      return BinaryOperator::CreateFSubFMF(X, FAdd, &I);
    }
  }

  // No fold applied.
  return nullptr;
}
diff --git a/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
new file mode 100644
index 000000000000..ae8865651ece
--- /dev/null
+++ b/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
@@ -0,0 +1,3838 @@
//===- InstCombineAndOrXor.cpp --------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the visitAnd, visitOr, and visitXor functions.
//
//===----------------------------------------------------------------------===//

#include "InstCombineInternal.h"
#include "llvm/Analysis/CmpInstAnalysis.h"
#include "llvm/Analysis/InstructionSimplify.h"
#include "llvm/IR/ConstantRange.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/Transforms/InstCombine/InstCombiner.h"
#include "llvm/Transforms/Utils/Local.h"

using namespace llvm;
using namespace PatternMatch;

#define DEBUG_TYPE "instcombine"

/// This is the complement of getICmpCode, which turns an opcode and two
/// operands into either a constant true or false, or a brand new ICmp
/// instruction. The sign is passed in to determine which kind of predicate to
/// use in the new icmp instruction.
+static Value *getNewICmpValue(unsigned Code, bool Sign, Value *LHS, Value *RHS, + InstCombiner::BuilderTy &Builder) { + ICmpInst::Predicate NewPred; + if (Constant *TorF = getPredForICmpCode(Code, Sign, LHS->getType(), NewPred)) + return TorF; + return Builder.CreateICmp(NewPred, LHS, RHS); +} + +/// This is the complement of getFCmpCode, which turns an opcode and two +/// operands into either a FCmp instruction, or a true/false constant. +static Value *getFCmpValue(unsigned Code, Value *LHS, Value *RHS, + InstCombiner::BuilderTy &Builder) { + FCmpInst::Predicate NewPred; + if (Constant *TorF = getPredForFCmpCode(Code, LHS->getType(), NewPred)) + return TorF; + return Builder.CreateFCmp(NewPred, LHS, RHS); +} + +/// Transform BITWISE_OP(BSWAP(A),BSWAP(B)) or +/// BITWISE_OP(BSWAP(A), Constant) to BSWAP(BITWISE_OP(A, B)) +/// \param I Binary operator to transform. +/// \return Pointer to node that must replace the original binary operator, or +/// null pointer if no transformation was made. +static Value *SimplifyBSwap(BinaryOperator &I, + InstCombiner::BuilderTy &Builder) { + assert(I.isBitwiseLogicOp() && "Unexpected opcode for bswap simplifying"); + + Value *OldLHS = I.getOperand(0); + Value *OldRHS = I.getOperand(1); + + Value *NewLHS; + if (!match(OldLHS, m_BSwap(m_Value(NewLHS)))) + return nullptr; + + Value *NewRHS; + const APInt *C; + + if (match(OldRHS, m_BSwap(m_Value(NewRHS)))) { + // OP( BSWAP(x), BSWAP(y) ) -> BSWAP( OP(x, y) ) + if (!OldLHS->hasOneUse() && !OldRHS->hasOneUse()) + return nullptr; + // NewRHS initialized by the matcher. 
+ } else if (match(OldRHS, m_APInt(C))) { + // OP( BSWAP(x), CONSTANT ) -> BSWAP( OP(x, BSWAP(CONSTANT) ) ) + if (!OldLHS->hasOneUse()) + return nullptr; + NewRHS = ConstantInt::get(I.getType(), C->byteSwap()); + } else + return nullptr; + + Value *BinOp = Builder.CreateBinOp(I.getOpcode(), NewLHS, NewRHS); + Function *F = Intrinsic::getDeclaration(I.getModule(), Intrinsic::bswap, + I.getType()); + return Builder.CreateCall(F, BinOp); +} + +/// Emit a computation of: (V >= Lo && V < Hi) if Inside is true, otherwise +/// (V < Lo || V >= Hi). This method expects that Lo < Hi. IsSigned indicates +/// whether to treat V, Lo, and Hi as signed or not. +Value *InstCombinerImpl::insertRangeTest(Value *V, const APInt &Lo, + const APInt &Hi, bool isSigned, + bool Inside) { + assert((isSigned ? Lo.slt(Hi) : Lo.ult(Hi)) && + "Lo is not < Hi in range emission code!"); + + Type *Ty = V->getType(); + + // V >= Min && V < Hi --> V < Hi + // V < Min || V >= Hi --> V >= Hi + ICmpInst::Predicate Pred = Inside ? ICmpInst::ICMP_ULT : ICmpInst::ICMP_UGE; + if (isSigned ? Lo.isMinSignedValue() : Lo.isMinValue()) { + Pred = isSigned ? ICmpInst::getSignedPredicate(Pred) : Pred; + return Builder.CreateICmp(Pred, V, ConstantInt::get(Ty, Hi)); + } + + // V >= Lo && V < Hi --> V - Lo u< Hi - Lo + // V < Lo || V >= Hi --> V - Lo u>= Hi - Lo + Value *VMinusLo = + Builder.CreateSub(V, ConstantInt::get(Ty, Lo), V->getName() + ".off"); + Constant *HiMinusLo = ConstantInt::get(Ty, Hi - Lo); + return Builder.CreateICmp(Pred, VMinusLo, HiMinusLo); +} + +/// Classify (icmp eq (A & B), C) and (icmp ne (A & B), C) as matching patterns +/// that can be simplified. +/// One of A and B is considered the mask. The other is the value. This is +/// described as the "AMask" or "BMask" part of the enum. If the enum contains +/// only "Mask", then both A and B can be considered masks. If A is the mask, +/// then it was proven that (A & C) == C. This is trivial if C == A or C == 0. 
+/// If both A and C are constants, this proof is also easy. +/// For the following explanations, we assume that A is the mask. +/// +/// "AllOnes" declares that the comparison is true only if (A & B) == A or all +/// bits of A are set in B. +/// Example: (icmp eq (A & 3), 3) -> AMask_AllOnes +/// +/// "AllZeros" declares that the comparison is true only if (A & B) == 0 or all +/// bits of A are cleared in B. +/// Example: (icmp eq (A & 3), 0) -> Mask_AllZeroes +/// +/// "Mixed" declares that (A & B) == C and C might or might not contain any +/// number of one bits and zero bits. +/// Example: (icmp eq (A & 3), 1) -> AMask_Mixed +/// +/// "Not" means that in above descriptions "==" should be replaced by "!=". +/// Example: (icmp ne (A & 3), 3) -> AMask_NotAllOnes +/// +/// If the mask A contains a single bit, then the following is equivalent: +/// (icmp eq (A & B), A) equals (icmp ne (A & B), 0) +/// (icmp ne (A & B), A) equals (icmp eq (A & B), 0) +enum MaskedICmpType { + AMask_AllOnes = 1, + AMask_NotAllOnes = 2, + BMask_AllOnes = 4, + BMask_NotAllOnes = 8, + Mask_AllZeros = 16, + Mask_NotAllZeros = 32, + AMask_Mixed = 64, + AMask_NotMixed = 128, + BMask_Mixed = 256, + BMask_NotMixed = 512 +}; + +/// Return the set of patterns (from MaskedICmpType) that (icmp SCC (A & B), C) +/// satisfies. +static unsigned getMaskedICmpType(Value *A, Value *B, Value *C, + ICmpInst::Predicate Pred) { + const APInt *ConstA = nullptr, *ConstB = nullptr, *ConstC = nullptr; + match(A, m_APInt(ConstA)); + match(B, m_APInt(ConstB)); + match(C, m_APInt(ConstC)); + bool IsEq = (Pred == ICmpInst::ICMP_EQ); + bool IsAPow2 = ConstA && ConstA->isPowerOf2(); + bool IsBPow2 = ConstB && ConstB->isPowerOf2(); + unsigned MaskVal = 0; + if (ConstC && ConstC->isZero()) { + // if C is zero, then both A and B qualify as mask + MaskVal |= (IsEq ? (Mask_AllZeros | AMask_Mixed | BMask_Mixed) + : (Mask_NotAllZeros | AMask_NotMixed | BMask_NotMixed)); + if (IsAPow2) + MaskVal |= (IsEq ? 
(AMask_NotAllOnes | AMask_NotMixed) + : (AMask_AllOnes | AMask_Mixed)); + if (IsBPow2) + MaskVal |= (IsEq ? (BMask_NotAllOnes | BMask_NotMixed) + : (BMask_AllOnes | BMask_Mixed)); + return MaskVal; + } + + if (A == C) { + MaskVal |= (IsEq ? (AMask_AllOnes | AMask_Mixed) + : (AMask_NotAllOnes | AMask_NotMixed)); + if (IsAPow2) + MaskVal |= (IsEq ? (Mask_NotAllZeros | AMask_NotMixed) + : (Mask_AllZeros | AMask_Mixed)); + } else if (ConstA && ConstC && ConstC->isSubsetOf(*ConstA)) { + MaskVal |= (IsEq ? AMask_Mixed : AMask_NotMixed); + } + + if (B == C) { + MaskVal |= (IsEq ? (BMask_AllOnes | BMask_Mixed) + : (BMask_NotAllOnes | BMask_NotMixed)); + if (IsBPow2) + MaskVal |= (IsEq ? (Mask_NotAllZeros | BMask_NotMixed) + : (Mask_AllZeros | BMask_Mixed)); + } else if (ConstB && ConstC && ConstC->isSubsetOf(*ConstB)) { + MaskVal |= (IsEq ? BMask_Mixed : BMask_NotMixed); + } + + return MaskVal; +} + +/// Convert an analysis of a masked ICmp into its equivalent if all boolean +/// operations had the opposite sense. Since each "NotXXX" flag (recording !=) +/// is adjacent to the corresponding normal flag (recording ==), this just +/// involves swapping those bits over. +static unsigned conjugateICmpMask(unsigned Mask) { + unsigned NewMask; + NewMask = (Mask & (AMask_AllOnes | BMask_AllOnes | Mask_AllZeros | + AMask_Mixed | BMask_Mixed)) + << 1; + + NewMask |= (Mask & (AMask_NotAllOnes | BMask_NotAllOnes | Mask_NotAllZeros | + AMask_NotMixed | BMask_NotMixed)) + >> 1; + + return NewMask; +} + +// Adapts the external decomposeBitTestICmp for local use. +static bool decomposeBitTestICmp(Value *LHS, Value *RHS, CmpInst::Predicate &Pred, + Value *&X, Value *&Y, Value *&Z) { + APInt Mask; + if (!llvm::decomposeBitTestICmp(LHS, RHS, Pred, X, Mask)) + return false; + + Y = ConstantInt::get(X->getType(), Mask); + Z = ConstantInt::get(X->getType(), 0); + return true; +} + +/// Handle (icmp(A & B) ==/!= C) &/| (icmp(A & D) ==/!= E). 
+/// Return the pattern classes (from MaskedICmpType) for the left hand side and +/// the right hand side as a pair. +/// LHS and RHS are the left hand side and the right hand side ICmps and PredL +/// and PredR are their predicates, respectively. +static +Optional<std::pair<unsigned, unsigned>> +getMaskedTypeForICmpPair(Value *&A, Value *&B, Value *&C, + Value *&D, Value *&E, ICmpInst *LHS, + ICmpInst *RHS, + ICmpInst::Predicate &PredL, + ICmpInst::Predicate &PredR) { + // Don't allow pointers. Splat vectors are fine. + if (!LHS->getOperand(0)->getType()->isIntOrIntVectorTy() || + !RHS->getOperand(0)->getType()->isIntOrIntVectorTy()) + return None; + + // Here comes the tricky part: + // LHS might be of the form L11 & L12 == X, X == L21 & L22, + // and L11 & L12 == L21 & L22. The same goes for RHS. + // Now we must find those components L** and R**, that are equal, so + // that we can extract the parameters A, B, C, D, and E for the canonical + // above. + Value *L1 = LHS->getOperand(0); + Value *L2 = LHS->getOperand(1); + Value *L11, *L12, *L21, *L22; + // Check whether the icmp can be decomposed into a bit test. + if (decomposeBitTestICmp(L1, L2, PredL, L11, L12, L2)) { + L21 = L22 = L1 = nullptr; + } else { + // Look for ANDs in the LHS icmp. + if (!match(L1, m_And(m_Value(L11), m_Value(L12)))) { + // Any icmp can be viewed as being trivially masked; if it allows us to + // remove one, it's worth it. + L11 = L1; + L12 = Constant::getAllOnesValue(L1->getType()); + } + + if (!match(L2, m_And(m_Value(L21), m_Value(L22)))) { + L21 = L2; + L22 = Constant::getAllOnesValue(L2->getType()); + } + } + + // Bail if LHS was a icmp that can't be decomposed into an equality. 
+ if (!ICmpInst::isEquality(PredL)) + return None; + + Value *R1 = RHS->getOperand(0); + Value *R2 = RHS->getOperand(1); + Value *R11, *R12; + bool Ok = false; + if (decomposeBitTestICmp(R1, R2, PredR, R11, R12, R2)) { + if (R11 == L11 || R11 == L12 || R11 == L21 || R11 == L22) { + A = R11; + D = R12; + } else if (R12 == L11 || R12 == L12 || R12 == L21 || R12 == L22) { + A = R12; + D = R11; + } else { + return None; + } + E = R2; + R1 = nullptr; + Ok = true; + } else { + if (!match(R1, m_And(m_Value(R11), m_Value(R12)))) { + // As before, model no mask as a trivial mask if it'll let us do an + // optimization. + R11 = R1; + R12 = Constant::getAllOnesValue(R1->getType()); + } + + if (R11 == L11 || R11 == L12 || R11 == L21 || R11 == L22) { + A = R11; + D = R12; + E = R2; + Ok = true; + } else if (R12 == L11 || R12 == L12 || R12 == L21 || R12 == L22) { + A = R12; + D = R11; + E = R2; + Ok = true; + } + } + + // Bail if RHS was a icmp that can't be decomposed into an equality. + if (!ICmpInst::isEquality(PredR)) + return None; + + // Look for ANDs on the right side of the RHS icmp. 
+  if (!Ok) {
+    if (!match(R2, m_And(m_Value(R11), m_Value(R12)))) {
+      R11 = R2;
+      R12 = Constant::getAllOnesValue(R2->getType());
+    }
+
+    if (R11 == L11 || R11 == L12 || R11 == L21 || R11 == L22) {
+      A = R11;
+      D = R12;
+      E = R1;
+      Ok = true;
+    } else if (R12 == L11 || R12 == L12 || R12 == L21 || R12 == L22) {
+      A = R12;
+      D = R11;
+      E = R1;
+      Ok = true;
+    } else {
+      return None;
+    }
+
+    assert(Ok && "Failed to find AND on the right side of the RHS icmp.");
+  }
+
+  if (L11 == A) {
+    B = L12;
+    C = L2;
+  } else if (L12 == A) {
+    B = L11;
+    C = L2;
+  } else if (L21 == A) {
+    B = L22;
+    C = L1;
+  } else if (L22 == A) {
+    B = L21;
+    C = L1;
+  }
+
+  unsigned LeftType = getMaskedICmpType(A, B, C, PredL);
+  unsigned RightType = getMaskedICmpType(A, D, E, PredR);
+  return Optional<std::pair<unsigned, unsigned>>(std::make_pair(LeftType, RightType));
+}
+
+/// Try to fold (icmp(A & B) ==/!= C) &/| (icmp(A & D) ==/!= E) into a single
+/// (icmp(A & X) ==/!= Y), where the left-hand side is of type Mask_NotAllZeros
+/// and the right hand side is of type BMask_Mixed. For example,
+/// (icmp (A & 12) != 0) & (icmp (A & 15) == 8) -> (icmp (A & 15) == 8).
+/// Also used for logical and/or, must be poison safe.
+///
+/// Returns nullptr when no fold applies. Otherwise returns either RHS itself,
+/// an i1 constant of LHS's type, or a newly built (icmp (and A, X), Y).
+static Value *foldLogOpOfMaskedICmps_NotAllZeros_BMask_Mixed(
+    ICmpInst *LHS, ICmpInst *RHS, bool IsAnd, Value *A, Value *B, Value *C,
+    Value *D, Value *E, ICmpInst::Predicate PredL, ICmpInst::Predicate PredR,
+    InstCombiner::BuilderTy &Builder) {
+  // We are given the canonical form:
+  //   (icmp ne (A & B), 0) & (icmp eq (A & D), E).
+  // where D & E == E.
+  //
+  // If IsAnd is false, we get it in negated form:
+  //   (icmp eq (A & B), 0) | (icmp ne (A & D), E) ->
+  //      !((icmp ne (A & B), 0) & (icmp eq (A & D), E)).
+  //
+  // We currently handle only the case where B, C, D, and E are all constant.
+  // (C merely has to be a constant; its value is not used below.)
+  //
+  const APInt *BCst, *CCst, *DCst, *OrigECst;
+  if (!match(B, m_APInt(BCst)) || !match(C, m_APInt(CCst)) ||
+      !match(D, m_APInt(DCst)) || !match(E, m_APInt(OrigECst)))
+    return nullptr;
+
+  ICmpInst::Predicate NewCC = IsAnd ? ICmpInst::ICMP_EQ : ICmpInst::ICMP_NE;
+
+  // Update E to the canonical form when D is a power of two and RHS is
+  // canonicalized as,
+  // (icmp ne (A & D), 0) -> (icmp eq (A & D), D) or
+  // (icmp ne (A & D), D) -> (icmp eq (A & D), 0).
+  APInt ECst = *OrigECst;
+  if (PredR != NewCC)
+    ECst ^= *DCst;
+
+  // If B or D is zero, skip: in that case LHS or RHS can be trivially folded
+  // by other folding rules and this pattern won't apply any more.
+  if (*BCst == 0 || *DCst == 0)
+    return nullptr;
+
+  // If B and D don't intersect, i.e. (B & D) == 0, no folding because we can't
+  // deduce anything from it.
+  // For example,
+  // (icmp ne (A & 12), 0) & (icmp eq (A & 3), 1) -> no folding.
+  if ((*BCst & *DCst) == 0)
+    return nullptr;
+
+  // If the following two conditions are met:
+  //
+  // 1. mask B covers only a single bit that's not covered by mask D, that is,
+  // (B & (B ^ D)) is a power of 2 (in other words, B minus the intersection of
+  // B and D has only one bit set) and,
+  //
+  // 2. RHS (and E) indicates that the rest of B's bits are zero (in other
+  // words, the intersection of B and D is zero), that is, ((B & D) & E) == 0
+  //
+  // then that single bit in B must be one and thus the whole expression can be
+  // folded to
+  //   (A & (B | D)) == (B & (B ^ D)) | E.
+  //
+  // For example,
+  // (icmp ne (A & 12), 0) & (icmp eq (A & 7), 1) -> (icmp eq (A & 15), 9)
+  // (icmp ne (A & 15), 0) & (icmp eq (A & 7), 0) -> (icmp eq (A & 15), 8)
+  if ((((*BCst & *DCst) & ECst) == 0) &&
+      (*BCst & (*BCst ^ *DCst)).isPowerOf2()) {
+    APInt BorD = *BCst | *DCst;
+    APInt BandBxorDorE = (*BCst & (*BCst ^ *DCst)) | ECst;
+    Value *NewMask = ConstantInt::get(A->getType(), BorD);
+    Value *NewMaskedValue = ConstantInt::get(A->getType(), BandBxorDorE);
+    Value *NewAnd = Builder.CreateAnd(A, NewMask);
+    return Builder.CreateICmp(NewCC, NewAnd, NewMaskedValue);
+  }
+
+  // IsSubSetOrEqual(C1, C2): every bit set in C1 is also set in C2.
+  auto IsSubSetOrEqual = [](const APInt *C1, const APInt *C2) {
+    return (*C1 & *C2) == *C1;
+  };
+  // IsSuperSetOrEqual(C1, C2): every bit set in C2 is also set in C1.
+  auto IsSuperSetOrEqual = [](const APInt *C1, const APInt *C2) {
+    return (*C1 & *C2) == *C2;
+  };
+
+  // In the following, we consider only the cases where B is a superset of D, B
+  // is a subset of D, or B == D because otherwise there's at least one bit
+  // covered by B but not D, in which case we can't deduce much from it, so
+  // no folding (aside from the single must-be-one bit case right above.)
+  // For example,
+  // (icmp ne (A & 14), 0) & (icmp eq (A & 3), 1) -> no folding.
+  if (!IsSubSetOrEqual(BCst, DCst) && !IsSuperSetOrEqual(BCst, DCst))
+    return nullptr;
+
+  // At this point, either B is a superset of D, B is a subset of D or B == D.
+
+  // If E is zero and B is a subset of (or equal to) D, LHS and RHS contradict
+  // and the whole expression becomes false (or true if negated); otherwise, no
+  // folding.
+  // For example,
+  // (icmp ne (A & 3), 0) & (icmp eq (A & 7), 0) -> false.
+  // (icmp ne (A & 15), 0) & (icmp eq (A & 3), 0) -> no folding.
+  if (ECst.isZero()) {
+    if (IsSubSetOrEqual(BCst, DCst))
+      return ConstantInt::get(LHS->getType(), !IsAnd);
+    return nullptr;
+  }
+
+  // At this point, B, D, E aren't zero and (B & D) == B, (B & D) == D or B ==
+  // D. If B is a superset of (or equal to) D, since E is not zero, LHS is
+  // subsumed by RHS (RHS implies LHS.) So the whole expression becomes
+  // RHS. For example,
+  // (icmp ne (A & 255), 0) & (icmp eq (A & 15), 8) -> (icmp eq (A & 15), 8).
+  // (icmp ne (A & 15), 0) & (icmp eq (A & 15), 8) -> (icmp eq (A & 15), 8).
+  if (IsSuperSetOrEqual(BCst, DCst))
+    return RHS;
+  // Otherwise, B is a subset of D. If B and E have a common bit set,
+  // i.e. (B & E) != 0, then LHS is subsumed by RHS. For example,
+  // (icmp ne (A & 12), 0) & (icmp eq (A & 15), 8) -> (icmp eq (A & 15), 8).
+  assert(IsSubSetOrEqual(BCst, DCst) && "Precondition due to above code");
+  if ((*BCst & ECst) != 0)
+    return RHS;
+  // Otherwise, LHS and RHS contradict and the whole expression becomes false
+  // (or true if negated.) For example,
+  // (icmp ne (A & 7), 0) & (icmp eq (A & 15), 8) -> false.
+  // (icmp ne (A & 6), 0) & (icmp eq (A & 15), 8) -> false.
+  return ConstantInt::get(LHS->getType(), !IsAnd);
+}
+
+/// Try to fold (icmp(A & B) ==/!= 0) &/| (icmp(A & D) ==/!= E) into a single
+/// (icmp(A & X) ==/!= Y), where the left-hand side and the right hand side
+/// aren't of the common mask pattern type.
+/// Also used for logical and/or, must be poison safe.
+///
+/// \p LHSMask and \p RHSMask are the mask-type bit sets of the two compares;
+/// for the 'or' form (\p IsAnd false) they are conjugated first so that the
+/// checks below can be written for the 'and' form only.
+static Value *foldLogOpOfMaskedICmpsAsymmetric(
+    ICmpInst *LHS, ICmpInst *RHS, bool IsAnd, Value *A, Value *B, Value *C,
+    Value *D, Value *E, ICmpInst::Predicate PredL, ICmpInst::Predicate PredR,
+    unsigned LHSMask, unsigned RHSMask, InstCombiner::BuilderTy &Builder) {
+  assert(ICmpInst::isEquality(PredL) && ICmpInst::isEquality(PredR) &&
+         "Expected equality predicates for masked type of icmps.");
+  // Handle Mask_NotAllZeros-BMask_Mixed cases.
+  // (icmp ne/eq (A & B), C) &/| (icmp eq/ne (A & D), E), or
+  // (icmp eq/ne (A & B), C) &/| (icmp ne/eq (A & D), E)
+  //    which gets swapped to
+  //    (icmp ne/eq (A & D), E) &/| (icmp eq/ne (A & B), C).
+  if (!IsAnd) {
+    LHSMask = conjugateICmpMask(LHSMask);
+    RHSMask = conjugateICmpMask(RHSMask);
+  }
+  if ((LHSMask & Mask_NotAllZeros) && (RHSMask & BMask_Mixed)) {
+    if (Value *V = foldLogOpOfMaskedICmps_NotAllZeros_BMask_Mixed(
+            LHS, RHS, IsAnd, A, B, C, D, E,
+            PredL, PredR, Builder)) {
+      return V;
+    }
+  } else if ((LHSMask & BMask_Mixed) && (RHSMask & Mask_NotAllZeros)) {
+    // Same fold with the roles of the two compares (and of B/C vs. D/E)
+    // exchanged.
+    if (Value *V = foldLogOpOfMaskedICmps_NotAllZeros_BMask_Mixed(
+            RHS, LHS, IsAnd, A, D, E, B, C,
+            PredR, PredL, Builder)) {
+      return V;
+    }
+  }
+  return nullptr;
+}
+
+/// Try to fold (icmp(A & B) ==/!= C) &/| (icmp(A & D) ==/!= E)
+/// into a single (icmp(A & X) ==/!= Y).
+///
+/// \p IsAnd selects the 'and' form; otherwise the 'or' form is folded.
+/// When \p IsLogical is set, the folds that build a new mask from B and D
+/// are skipped unless D is guaranteed not to be undef or poison.
+/// Returns nullptr when no fold applies.
+static Value *foldLogOpOfMaskedICmps(ICmpInst *LHS, ICmpInst *RHS, bool IsAnd,
+                                     bool IsLogical,
+                                     InstCombiner::BuilderTy &Builder) {
+  Value *A = nullptr, *B = nullptr, *C = nullptr, *D = nullptr, *E = nullptr;
+  ICmpInst::Predicate PredL = LHS->getPredicate(), PredR = RHS->getPredicate();
+  Optional<std::pair<unsigned, unsigned>> MaskPair =
+      getMaskedTypeForICmpPair(A, B, C, D, E, LHS, RHS, PredL, PredR);
+  if (!MaskPair)
+    return nullptr;
+  assert(ICmpInst::isEquality(PredL) && ICmpInst::isEquality(PredR) &&
+         "Expected equality predicates for masked type of icmps.");
+  unsigned LHSMask = MaskPair->first;
+  unsigned RHSMask = MaskPair->second;
+  unsigned Mask = LHSMask & RHSMask;
+  if (Mask == 0) {
+    // Even if the two sides don't share a common pattern, check if folding can
+    // still happen.
+    if (Value *V = foldLogOpOfMaskedICmpsAsymmetric(
+            LHS, RHS, IsAnd, A, B, C, D, E, PredL, PredR, LHSMask, RHSMask,
+            Builder))
+      return V;
+    return nullptr;
+  }
+
+  // In full generality:
+  //     (icmp (A & B) Op C) | (icmp (A & D) Op E)
+  // ==  ![ (icmp (A & B) !Op C) & (icmp (A & D) !Op E) ]
+  //
+  // If the latter can be converted into (icmp (A & X) Op Y) then the former is
+  // equivalent to (icmp (A & X) !Op Y).
+  //
+  // Therefore, we can pretend for the rest of this function that we're dealing
+  // with the conjunction, provided we flip the sense of any comparisons (both
+  // input and output).
+
+  // In most cases we're going to produce an EQ for the "&&" case.
+  ICmpInst::Predicate NewCC = IsAnd ? ICmpInst::ICMP_EQ : ICmpInst::ICMP_NE;
+  if (!IsAnd) {
+    // Convert the masking analysis into its equivalent with negated
+    // comparisons.
+    Mask = conjugateICmpMask(Mask);
+  }
+
+  if (Mask & Mask_AllZeros) {
+    // (icmp eq (A & B), 0) & (icmp eq (A & D), 0)
+    // -> (icmp eq (A & (B|D)), 0)
+    if (IsLogical && !isGuaranteedNotToBeUndefOrPoison(D))
+      return nullptr; // TODO: Use freeze?
+    Value *NewOr = Builder.CreateOr(B, D);
+    Value *NewAnd = Builder.CreateAnd(A, NewOr);
+    // We can't use C as zero because we might actually handle
+    //   (icmp ne (A & B), B) & (icmp ne (A & D), D)
+    // with B and D, having a single bit set.
+    Value *Zero = Constant::getNullValue(A->getType());
+    return Builder.CreateICmp(NewCC, NewAnd, Zero);
+  }
+  if (Mask & BMask_AllOnes) {
+    // (icmp eq (A & B), B) & (icmp eq (A & D), D)
+    // -> (icmp eq (A & (B|D)), (B|D))
+    if (IsLogical && !isGuaranteedNotToBeUndefOrPoison(D))
+      return nullptr; // TODO: Use freeze?
+    Value *NewOr = Builder.CreateOr(B, D);
+    Value *NewAnd = Builder.CreateAnd(A, NewOr);
+    return Builder.CreateICmp(NewCC, NewAnd, NewOr);
+  }
+  if (Mask & AMask_AllOnes) {
+    // (icmp eq (A & B), A) & (icmp eq (A & D), A)
+    // -> (icmp eq (A & (B&D)), A)
+    if (IsLogical && !isGuaranteedNotToBeUndefOrPoison(D))
+      return nullptr; // TODO: Use freeze?
+    Value *NewAnd1 = Builder.CreateAnd(B, D);
+    Value *NewAnd2 = Builder.CreateAnd(A, NewAnd1);
+    return Builder.CreateICmp(NewCC, NewAnd2, A);
+  }
+
+  // Remaining cases assume at least that B and D are constant, and depend on
+  // their actual values. This isn't strictly necessary, just a "handle the
+  // easy cases for now" decision.
+  const APInt *ConstB, *ConstD;
+  if (!match(B, m_APInt(ConstB)) || !match(D, m_APInt(ConstD)))
+    return nullptr;
+
+  if (Mask & (Mask_NotAllZeros | BMask_NotAllOnes)) {
+    // (icmp ne (A & B), 0) & (icmp ne (A & D), 0) and
+    //     (icmp ne (A & B), B) & (icmp ne (A & D), D)
+    // -> (icmp ne (A & B), 0) or (icmp ne (A & D), 0)
+    // Only valid if one of the masks is a superset of the other (check "B&D" is
+    // the same as either B or D).
+    APInt NewMask = *ConstB & *ConstD;
+    if (NewMask == *ConstB)
+      return LHS;
+    else if (NewMask == *ConstD)
+      return RHS;
+  }
+
+  if (Mask & AMask_NotAllOnes) {
+    // (icmp ne (A & B), B) & (icmp ne (A & D), D)
+    // -> (icmp ne (A & B), A) or (icmp ne (A & D), A)
+    // Only valid if one of the masks is a superset of the other (check "B|D" is
+    // the same as either B or D).
+    APInt NewMask = *ConstB | *ConstD;
+    if (NewMask == *ConstB)
+      return LHS;
+    else if (NewMask == *ConstD)
+      return RHS;
+  }
+
+  if (Mask & BMask_Mixed) {
+    // (icmp eq (A & B), C) & (icmp eq (A & D), E)
+    // We already know that B & C == C && D & E == E.
+    // If we can prove that (B & D) & (C ^ E) == 0, that is, the bits of
+    // C and E, which are shared by both the mask B and the mask D, don't
+    // contradict, then we can transform to
+    // -> (icmp eq (A & (B|D)), (C|E))
+    // Currently, we only handle the case of B, C, D, and E being constant.
+    // We can't simply use C and E because we might actually handle
+    //   (icmp ne (A & B), B) & (icmp eq (A & D), D)
+    // with B and D, having a single bit set.
+    const APInt *OldConstC, *OldConstE;
+    if (!match(C, m_APInt(OldConstC)) || !match(E, m_APInt(OldConstE)))
+      return nullptr;
+
+    // Canonicalize C and E to the sense of NewCC: a compare whose predicate
+    // differs from NewCC has its constant flipped within its own mask.
+    const APInt ConstC = PredL != NewCC ? *ConstB ^ *OldConstC : *OldConstC;
+    const APInt ConstE = PredR != NewCC ? *ConstD ^ *OldConstE : *OldConstE;
+
+    // If there is a conflict, we should actually return a false for the
+    // whole construct.
+    if (((*ConstB & *ConstD) & (ConstC ^ ConstE)).getBoolValue())
+      return ConstantInt::get(LHS->getType(), !IsAnd);
+
+    Value *NewOr1 = Builder.CreateOr(B, D);
+    Value *NewAnd = Builder.CreateAnd(A, NewOr1);
+    Constant *NewOr2 = ConstantInt::get(A->getType(), ConstC | ConstE);
+    return Builder.CreateICmp(NewCC, NewAnd, NewOr2);
+  }
+
+  return nullptr;
+}
+
+/// Try to fold a signed range checked with lower bound 0 to an unsigned icmp.
+/// Example: (icmp sge x, 0) & (icmp slt x, n) --> icmp ult x, n
+/// If \p Inverted is true then the check is for the inverted range, e.g.
+/// (icmp slt x, 0) | (icmp sgt x, n) --> icmp ugt x, n
+/// Returns nullptr if the pair of compares does not match this shape or the
+/// upper bound is not known to be non-negative.
+Value *InstCombinerImpl::simplifyRangeCheck(ICmpInst *Cmp0, ICmpInst *Cmp1,
+                                         bool Inverted) {
+  // Check the lower range comparison, e.g. x >= 0
+  // InstCombine already ensured that if there is a constant it's on the RHS.
+  ConstantInt *RangeStart = dyn_cast<ConstantInt>(Cmp0->getOperand(1));
+  if (!RangeStart)
+    return nullptr;
+
+  ICmpInst::Predicate Pred0 = (Inverted ? Cmp0->getInversePredicate() :
+                               Cmp0->getPredicate());
+
+  // Accept x > -1 or x >= 0 (after potentially inverting the predicate).
+  if (!((Pred0 == ICmpInst::ICMP_SGT && RangeStart->isMinusOne()) ||
+        (Pred0 == ICmpInst::ICMP_SGE && RangeStart->isZero())))
+    return nullptr;
+
+  ICmpInst::Predicate Pred1 = (Inverted ? Cmp1->getInversePredicate() :
+                               Cmp1->getPredicate());
+
+  Value *Input = Cmp0->getOperand(0);
+  Value *RangeEnd;
+  if (Cmp1->getOperand(0) == Input) {
+    // For the upper range compare we have: icmp x, n
+    RangeEnd = Cmp1->getOperand(1);
+  } else if (Cmp1->getOperand(1) == Input) {
+    // For the upper range compare we have: icmp n, x
+    RangeEnd = Cmp1->getOperand(0);
+    Pred1 = ICmpInst::getSwappedPredicate(Pred1);
+  } else {
+    return nullptr;
+  }
+
+  // Check the upper range comparison, e.g. x < n
+  ICmpInst::Predicate NewPred;
+  switch (Pred1) {
+    case ICmpInst::ICMP_SLT: NewPred = ICmpInst::ICMP_ULT; break;
+    case ICmpInst::ICMP_SLE: NewPred = ICmpInst::ICMP_ULE; break;
+    default: return nullptr;
+  }
+
+  // This simplification is only valid if the upper range is not negative.
+  KnownBits Known = computeKnownBits(RangeEnd, /*Depth=*/0, Cmp1);
+  if (!Known.isNonNegative())
+    return nullptr;
+
+  if (Inverted)
+    NewPred = ICmpInst::getInversePredicate(NewPred);
+
+  return Builder.CreateICmp(NewPred, Input, RangeEnd);
+}
+
+// Fold (iszero(A & K1) | iszero(A & K2)) -> (A & (K1 | K2)) != (K1 | K2)
+// Fold (!iszero(A & K1) & !iszero(A & K2)) -> (A & (K1 | K2)) == (K1 | K2)
+// K1 and K2 must each be known to be a power of two. For a logical and/or
+// (IsLogical), the RHS mask is frozen before it is merged into the new mask.
+Value *InstCombinerImpl::foldAndOrOfICmpsOfAndWithPow2(ICmpInst *LHS,
+                                                       ICmpInst *RHS,
+                                                       Instruction *CxtI,
+                                                       bool IsAnd,
+                                                       bool IsLogical) {
+  CmpInst::Predicate Pred = IsAnd ? CmpInst::ICMP_NE : CmpInst::ICMP_EQ;
+  if (LHS->getPredicate() != Pred || RHS->getPredicate() != Pred)
+    return nullptr;
+
+  if (!match(LHS->getOperand(1), m_Zero()) ||
+      !match(RHS->getOperand(1), m_Zero()))
+    return nullptr;
+
+  Value *L1, *L2, *R1, *R2;
+  if (match(LHS->getOperand(0), m_And(m_Value(L1), m_Value(L2))) &&
+      match(RHS->getOperand(0), m_And(m_Value(R1), m_Value(R2)))) {
+    // Commute the 'and' operands so that a shared operand, if any, ends up in
+    // L1 and R1.
+    if (L1 == R2 || L2 == R2)
+      std::swap(R1, R2);
+    if (L2 == R1)
+      std::swap(L1, L2);
+
+    if (L1 == R1 &&
+        isKnownToBeAPowerOfTwo(L2, false, 0, CxtI) &&
+        isKnownToBeAPowerOfTwo(R2, false, 0, CxtI)) {
+      // If this is a logical and/or, then we must prevent propagation of a
+      // poison value from the RHS by inserting freeze.
+      if (IsLogical)
+        R2 = Builder.CreateFreeze(R2);
+      Value *Mask = Builder.CreateOr(L2, R2);
+      Value *Masked = Builder.CreateAnd(L1, Mask);
+      auto NewPred = IsAnd ? CmpInst::ICMP_EQ : CmpInst::ICMP_NE;
+      return Builder.CreateICmp(NewPred, Masked, Mask);
+    }
+  }
+
+  return nullptr;
+}
+
+/// General pattern:
+///   X & Y
+///
+/// Where Y is checking that all the high bits (covered by a mask 4294967168)
+/// are uniform, i.e.  %arg & 4294967168  can be either  4294967168  or  0
+/// Pattern can be one of:
+///   %t = add        i32 %arg,    128
+///   %r = icmp   ult i32 %t,      256
+/// Or
+///   %t0 = shl       i32 %arg,    24
+///   %t1 = ashr      i32 %t0,     24
+///   %r  = icmp  eq  i32 %t1,     %arg
+/// Or
+///   %t0 = trunc     i32 %arg  to i8
+///   %t1 = sext      i8  %t0   to i32
+///   %r  = icmp  eq  i32 %t1,     %arg
+/// This pattern is a signed truncation check.
+///
+/// And X is checking that some bit in that same mask is zero.
+/// I.e. can be one of:
+///   %r = icmp sgt i32   %arg,    -1
+/// Or
+///   %t = and      i32   %arg,    2147483648
+///   %r = icmp eq  i32   %t,      0
+///
+/// Since we are checking that all the bits in that mask are the same,
+/// and a particular bit is zero, what we are really checking is that all the
+/// masked bits are zero.
+/// So this should be transformed to:
+///   %r = icmp ult i32 %arg, 128
+static Value *foldSignedTruncationCheck(ICmpInst *ICmp0, ICmpInst *ICmp1,
+                                        Instruction &CxtI,
+                                        InstCombiner::BuilderTy &Builder) {
+  assert(CxtI.getOpcode() == Instruction::And);
+
+  // Match  icmp ult (add %arg, C01), C1   (C1 == C01 << 1; powers of two)
+  auto tryToMatchSignedTruncationCheck = [](ICmpInst *ICmp, Value *&X,
+                                            APInt &SignBitMask) -> bool {
+    CmpInst::Predicate Pred;
+    const APInt *I01, *I1; // powers of two; I1 == I01 << 1
+    if (!(match(ICmp,
+                m_ICmp(Pred, m_Add(m_Value(X), m_Power2(I01)), m_Power2(I1))) &&
+          Pred == ICmpInst::ICMP_ULT && I1->ugt(*I01) && I01->shl(1) == *I1))
+      return false;
+    // Which bit is the new sign bit as per the 'signed truncation' pattern?
+    SignBitMask = *I01;
+    return true;
+  };
+
+  // One icmp needs to be 'signed truncation check'.
+  // We need to match this first, else we will mismatch commutative cases.
+  Value *X1;
+  APInt HighestBit;
+  ICmpInst *OtherICmp;
+  if (tryToMatchSignedTruncationCheck(ICmp1, X1, HighestBit))
+    OtherICmp = ICmp0;
+  else if (tryToMatchSignedTruncationCheck(ICmp0, X1, HighestBit))
+    OtherICmp = ICmp1;
+  else
+    return nullptr;
+
+  assert(HighestBit.isPowerOf2() && "expected to be power of two (non-zero)");
+
+  // Try to match/decompose into:  icmp eq (X & Mask), 0
+  auto tryToDecompose = [](ICmpInst *ICmp, Value *&X,
+                           APInt &UnsetBitsMask) -> bool {
+    CmpInst::Predicate Pred = ICmp->getPredicate();
+    // Can it be decomposed into  icmp eq (X & Mask), 0  ?
+    if (llvm::decomposeBitTestICmp(ICmp->getOperand(0), ICmp->getOperand(1),
+                                   Pred, X, UnsetBitsMask,
+                                   /*LookThroughTrunc=*/false) &&
+        Pred == ICmpInst::ICMP_EQ)
+      return true;
+    // Is it  icmp eq (X & Mask), 0  already?
+    const APInt *Mask;
+    if (match(ICmp, m_ICmp(Pred, m_And(m_Value(X), m_APInt(Mask)), m_Zero())) &&
+        Pred == ICmpInst::ICMP_EQ) {
+      UnsetBitsMask = *Mask;
+      return true;
+    }
+    return false;
+  };
+
+  // And the other icmp needs to be decomposable into a bit test.
+  Value *X0;
+  APInt UnsetBitsMask;
+  if (!tryToDecompose(OtherICmp, X0, UnsetBitsMask))
+    return nullptr;
+
+  assert(!UnsetBitsMask.isZero() && "empty mask makes no sense.");
+
+  // Are they working on the same value?
+  Value *X;
+  if (X1 == X0) {
+    // Ok as is.
+    X = X1;
+  } else if (match(X0, m_Trunc(m_Specific(X1)))) {
+    // The bit test looks at a truncation of X1: widen the mask to X1's width.
+    UnsetBitsMask = UnsetBitsMask.zext(X1->getType()->getScalarSizeInBits());
+    X = X1;
+  } else
+    return nullptr;
+
+  // So which bits should be uniform as per the 'signed truncation check'?
+  // (all the bits starting with (i.e. including) HighestBit)
+  APInt SignBitsMask = ~(HighestBit - 1U);
+
+  // UnsetBitsMask must have some common bits with SignBitsMask,
+  if (!UnsetBitsMask.intersects(SignBitsMask))
+    return nullptr;
+
+  // Does UnsetBitsMask contain any bits outside of SignBitsMask?
+  if (!UnsetBitsMask.isSubsetOf(SignBitsMask)) {
+    APInt OtherHighestBit = (~UnsetBitsMask) + 1U;
+    if (!OtherHighestBit.isPowerOf2())
+      return nullptr;
+    HighestBit = APIntOps::umin(HighestBit, OtherHighestBit);
+  }
+  // Else, if it does not, then all is ok as-is.
+
+  // %r = icmp ult %X, SignBit
+  return Builder.CreateICmpULT(X, ConstantInt::get(X->getType(), HighestBit),
+                               CxtI.getName() + ".simplified");
+}
+
+/// Fold (icmp eq ctpop(X) 1) | (icmp eq X 0) into (icmp ult ctpop(X) 2) and
+/// fold (icmp ne ctpop(X) 1) & (icmp ne X 0) into (icmp ugt ctpop(X) 1).
+/// Also used for logical and/or, must be poison safe.
+/// \p Cmp0 must be the ctpop compare and \p Cmp1 the compare against zero;
+/// the swapped order is not matched here.
+static Value *foldIsPowerOf2OrZero(ICmpInst *Cmp0, ICmpInst *Cmp1, bool IsAnd,
+                                   InstCombiner::BuilderTy &Builder) {
+  CmpInst::Predicate Pred0, Pred1;
+  Value *X;
+  if (!match(Cmp0, m_ICmp(Pred0, m_Intrinsic<Intrinsic::ctpop>(m_Value(X)),
+                          m_SpecificInt(1))) ||
+      !match(Cmp1, m_ICmp(Pred1, m_Specific(X), m_ZeroInt())))
+    return nullptr;
+
+  Value *CtPop = Cmp0->getOperand(0);
+  if (IsAnd && Pred0 == ICmpInst::ICMP_NE && Pred1 == ICmpInst::ICMP_NE)
+    return Builder.CreateICmpUGT(CtPop, ConstantInt::get(CtPop->getType(), 1));
+  if (!IsAnd && Pred0 == ICmpInst::ICMP_EQ && Pred1 == ICmpInst::ICMP_EQ)
+    return Builder.CreateICmpULT(CtPop, ConstantInt::get(CtPop->getType(), 2));
+
+  return nullptr;
+}
+
+/// Reduce a pair of compares that check if a value has exactly 1 bit set.
+/// Also used for logical and/or, must be poison safe.
+static Value *foldIsPowerOf2(ICmpInst *Cmp0, ICmpInst *Cmp1, bool JoinedByAnd,
+                             InstCombiner::BuilderTy &Builder) {
+  // Handle 'and' / 'or' commutation: make the equality check the first operand.
+  if (JoinedByAnd && Cmp1->getPredicate() == ICmpInst::ICMP_NE)
+    std::swap(Cmp0, Cmp1);
+  else if (!JoinedByAnd && Cmp1->getPredicate() == ICmpInst::ICMP_EQ)
+    std::swap(Cmp0, Cmp1);
+
+  // (X != 0) && (ctpop(X) u< 2) --> ctpop(X) == 1
+  CmpInst::Predicate Pred0, Pred1;
+  Value *X;
+  if (JoinedByAnd && match(Cmp0, m_ICmp(Pred0, m_Value(X), m_ZeroInt())) &&
+      match(Cmp1, m_ICmp(Pred1, m_Intrinsic<Intrinsic::ctpop>(m_Specific(X)),
+                         m_SpecificInt(2))) &&
+      Pred0 == ICmpInst::ICMP_NE && Pred1 == ICmpInst::ICMP_ULT) {
+    Value *CtPop = Cmp1->getOperand(0);
+    return Builder.CreateICmpEQ(CtPop, ConstantInt::get(CtPop->getType(), 1));
+  }
+  // (X == 0) || (ctpop(X) u> 1) --> ctpop(X) != 1
+  if (!JoinedByAnd && match(Cmp0, m_ICmp(Pred0, m_Value(X), m_ZeroInt())) &&
+      match(Cmp1, m_ICmp(Pred1, m_Intrinsic<Intrinsic::ctpop>(m_Specific(X)),
+                         m_SpecificInt(1))) &&
+      Pred0 == ICmpInst::ICMP_EQ && Pred1 == ICmpInst::ICMP_UGT) {
+    Value *CtPop = Cmp1->getOperand(0);
+    return Builder.CreateICmpNE(CtPop, ConstantInt::get(CtPop->getType(), 1));
+  }
+  return nullptr;
+}
+
+/// Commuted variants are assumed to be handled by calling this function again
+/// with the parameters swapped.
+static Value *foldUnsignedUnderflowCheck(ICmpInst *ZeroICmp,
+                                         ICmpInst *UnsignedICmp, bool IsAnd,
+                                         const SimplifyQuery &Q,
+                                         InstCombiner::BuilderTy &Builder) {
+  // ZeroICmp must be an equality compare of some value against zero;
+  // UnsignedICmp is the compare it is and'ed (IsAnd) or or'ed (!IsAnd) with.
+  // Returns the single replacement compare, or nullptr if nothing matches.
+  Value *ZeroCmpOp;
+  ICmpInst::Predicate EqPred;
+  if (!match(ZeroICmp, m_ICmp(EqPred, m_Value(ZeroCmpOp), m_Zero())) ||
+      !ICmpInst::isEquality(EqPred))
+    return nullptr;
+
+  auto IsKnownNonZero = [&](Value *V) {
+    return isKnownNonZero(V, Q.DL, /*Depth=*/0, Q.AC, Q.CxtI, Q.DT);
+  };
+
+  ICmpInst::Predicate UnsignedPred;
+
+  Value *A, *B;
+  if (match(UnsignedICmp,
+            m_c_ICmp(UnsignedPred, m_Specific(ZeroCmpOp), m_Value(A))) &&
+      match(ZeroCmpOp, m_c_Add(m_Specific(A), m_Value(B))) &&
+      (ZeroICmp->hasOneUse() || UnsignedICmp->hasOneUse())) {
+    // Arrange for NonZero to hold whichever of the two values is known to be
+    // non-zero (swapping them if the first is not) and report whether that
+    // succeeded. The query is not repeated once the first value is proven
+    // non-zero.
+    auto GetKnownNonZeroAndOther = [&](Value *&NonZero, Value *&Other) {
+      if (IsKnownNonZero(NonZero))
+        return true;
+      std::swap(NonZero, Other);
+      return IsKnownNonZero(NonZero);
+    };
+
+    // Given  ZeroCmpOp = (A + B)
+    //   ZeroCmpOp <  A && ZeroCmpOp != 0  -->  (0-X) <  Y  iff
+    //   ZeroCmpOp >= A || ZeroCmpOp == 0  -->  (0-X) >= Y  iff
+    // with X being the value (A/B) that is known to be non-zero,
+    // and Y being remaining value.
+    if (UnsignedPred == ICmpInst::ICMP_ULT && EqPred == ICmpInst::ICMP_NE &&
+        IsAnd && GetKnownNonZeroAndOther(B, A))
+      return Builder.CreateICmpULT(Builder.CreateNeg(B), A);
+    if (UnsignedPred == ICmpInst::ICMP_UGE && EqPred == ICmpInst::ICMP_EQ &&
+        !IsAnd && GetKnownNonZeroAndOther(B, A))
+      return Builder.CreateICmpUGE(Builder.CreateNeg(B), A);
+  }
+
+  // The remaining folds need ZeroCmpOp = (Base - Offset) and an unsigned
+  // compare of Base against Offset (in either operand order).
+  Value *Base, *Offset;
+  if (!match(ZeroCmpOp, m_Sub(m_Value(Base), m_Value(Offset))))
+    return nullptr;
+
+  if (!match(UnsignedICmp,
+             m_c_ICmp(UnsignedPred, m_Specific(Base), m_Specific(Offset))) ||
+      !ICmpInst::isUnsigned(UnsignedPred))
+    return nullptr;
+
+  // Base >=/> Offset && (Base - Offset) != 0  <-->  Base > Offset
+  // (no overflow and not null)
+  if ((UnsignedPred == ICmpInst::ICMP_UGE ||
+       UnsignedPred == ICmpInst::ICMP_UGT) &&
+      EqPred == ICmpInst::ICMP_NE && IsAnd)
+    return Builder.CreateICmpUGT(Base, Offset);
+
+  // Base <=/< Offset || (Base - Offset) == 0  <-->  Base <= Offset
+  // (overflow or null)
+  if ((UnsignedPred == ICmpInst::ICMP_ULE ||
+       UnsignedPred == ICmpInst::ICMP_ULT) &&
+      EqPred == ICmpInst::ICMP_EQ && !IsAnd)
+    return Builder.CreateICmpULE(Base, Offset);
+
+  // Base <= Offset && (Base - Offset) != 0  -->  Base < Offset
+  if (UnsignedPred == ICmpInst::ICMP_ULE && EqPred == ICmpInst::ICMP_NE &&
+      IsAnd)
+    return Builder.CreateICmpULT(Base, Offset);
+
+  // Base > Offset || (Base - Offset) == 0  -->  Base >= Offset
+  if (UnsignedPred == ICmpInst::ICMP_UGT && EqPred == ICmpInst::ICMP_EQ &&
+      !IsAnd)
+    return Builder.CreateICmpUGE(Base, Offset);
+
+  return nullptr;
+}
+
+/// A contiguous run of bits taken out of an integer value.
+struct IntPart {
+  Value *From;       ///< The value the bits are extracted from.
+  unsigned StartBit; ///< Index of the lowest extracted bit within From.
+  unsigned NumBits;  ///< Number of extracted bits.
+};
+
+/// Match an extraction of bits from an integer.
+static Optional<IntPart> matchIntPart(Value *V) {
+  // Only a single-use trunc is treated as an extraction.
+  Value *X;
+  if (!match(V, m_OneUse(m_Trunc(m_Value(X)))))
+    return None;
+
+  unsigned NumOriginalBits = X->getType()->getScalarSizeInBits();
+  unsigned NumExtractedBits = V->getType()->getScalarSizeInBits();
+  Value *Y;
+  const APInt *Shift;
+  // For a trunc(lshr Y, Shift) pattern, make sure we're only extracting bits
+  // from Y, not any shifted-in zeroes.
+  if (match(X, m_OneUse(m_LShr(m_Value(Y), m_APInt(Shift)))) &&
+      Shift->ule(NumOriginalBits - NumExtractedBits))
+    return {{Y, (unsigned)Shift->getZExtValue(), NumExtractedBits}};
+  // Plain trunc: the low NumExtractedBits bits of X.
+  return {{X, 0, NumExtractedBits}};
+}
+
+/// Materialize an extraction of bits from an integer in IR.
+/// Emits an lshr only for a non-zero start bit, and a trunc only if the part
+/// is narrower than the source value.
+static Value *extractIntPart(const IntPart &P, IRBuilderBase &Builder) {
+  Value *V = P.From;
+  if (P.StartBit)
+    V = Builder.CreateLShr(V, P.StartBit);
+  Type *TruncTy = V->getType()->getWithNewBitWidth(P.NumBits);
+  if (TruncTy != V->getType())
+    V = Builder.CreateTrunc(V, TruncTy);
+  return V;
+}
+
+/// (icmp eq X0, Y0) & (icmp eq X1, Y1) -> icmp eq X01, Y01
+/// (icmp ne X0, Y0) | (icmp ne X1, Y1) -> icmp ne X01, Y01
+/// where X0, X1 and Y0, Y1 are adjacent parts extracted from an integer.
+/// Both compares must have a single use.
+Value *InstCombinerImpl::foldEqOfParts(ICmpInst *Cmp0, ICmpInst *Cmp1,
+                                       bool IsAnd) {
+  if (!Cmp0->hasOneUse() || !Cmp1->hasOneUse())
+    return nullptr;
+
+  CmpInst::Predicate Pred = IsAnd ? CmpInst::ICMP_EQ : CmpInst::ICMP_NE;
+  if (Cmp0->getPredicate() != Pred || Cmp1->getPredicate() != Pred)
+    return nullptr;
+
+  Optional<IntPart> L0 = matchIntPart(Cmp0->getOperand(0));
+  Optional<IntPart> R0 = matchIntPart(Cmp0->getOperand(1));
+  Optional<IntPart> L1 = matchIntPart(Cmp1->getOperand(0));
+  Optional<IntPart> R1 = matchIntPart(Cmp1->getOperand(1));
+  if (!L0 || !R0 || !L1 || !R1)
+    return nullptr;
+
+  // Make sure the LHS/RHS compare a part of the same value, possibly after
+  // an operand swap.
+  if (L0->From != L1->From || R0->From != R1->From) {
+    if (L0->From != R1->From || R0->From != L1->From)
+      return nullptr;
+    std::swap(L1, R1);
+  }
+
+  // Make sure the extracted parts are adjacent, canonicalizing to L0/R0 being
+  // the low part and L1/R1 being the high part.
+  if (L0->StartBit + L0->NumBits != L1->StartBit ||
+      R0->StartBit + R0->NumBits != R1->StartBit) {
+    if (L1->StartBit + L1->NumBits != L0->StartBit ||
+        R1->StartBit + R1->NumBits != R0->StartBit)
+      return nullptr;
+    std::swap(L0, L1);
+    std::swap(R0, R1);
+  }
+
+  // We can simplify to a comparison of these larger parts of the integers.
+  IntPart L = {L0->From, L0->StartBit, L0->NumBits + L1->NumBits};
+  IntPart R = {R0->From, R0->StartBit, R0->NumBits + R1->NumBits};
+  Value *LValue = extractIntPart(L, Builder);
+  Value *RValue = extractIntPart(R, Builder);
+  return Builder.CreateICmp(Pred, LValue, RValue);
+}
+
+/// Reduce logic-of-compares with equality to a constant by substituting a
+/// common operand with the constant. Callers are expected to call this with
+/// Cmp0/Cmp1 switched to handle logic op commutativity.
+static Value *foldAndOrOfICmpsWithConstEq(ICmpInst *Cmp0, ICmpInst *Cmp1,
+                                          bool IsAnd,
+                                          InstCombiner::BuilderTy &Builder,
+                                          const SimplifyQuery &Q) {
+  // Match an equality compare with a non-poison constant as Cmp0.
+  // Also, give up if the compare can be constant-folded to avoid looping.
+  ICmpInst::Predicate Pred0;
+  Value *X;
+  Constant *C;
+  if (!match(Cmp0, m_ICmp(Pred0, m_Value(X), m_Constant(C))) ||
+      !isGuaranteedNotToBeUndefOrPoison(C) || isa<Constant>(X))
+    return nullptr;
+  if ((IsAnd && Pred0 != ICmpInst::ICMP_EQ) ||
+      (!IsAnd && Pred0 != ICmpInst::ICMP_NE))
+    return nullptr;
+
+  // The other compare must include a common operand (X). Canonicalize the
+  // common operand as operand 1 (Pred1 is swapped if the common operand was
+  // operand 0).
+  Value *Y;
+  ICmpInst::Predicate Pred1;
+  if (!match(Cmp1, m_c_ICmp(Pred1, m_Value(Y), m_Deferred(X))))
+    return nullptr;
+
+  // Replace variable with constant value equivalence to remove a variable use:
+  // (X == C) && (Y Pred1 X) --> (X == C) && (Y Pred1 C)
+  // (X != C) || (Y Pred1 X) --> (X != C) || (Y Pred1 C)
+  // Can think of the 'or' substitution with the 'and' bool equivalent:
+  // A || B --> A || (!A && B)
+  Value *SubstituteCmp = simplifyICmpInst(Pred1, Y, C, Q);
+  if (!SubstituteCmp) {
+    // If we need to create a new instruction, require that the old compare can
+    // be removed.
+    if (!Cmp1->hasOneUse())
+      return nullptr;
+    SubstituteCmp = Builder.CreateICmp(Pred1, Y, C);
+  }
+  return Builder.CreateBinOp(IsAnd ? Instruction::And : Instruction::Or, Cmp0,
+                             SubstituteCmp);
+}
+
+/// Fold (icmp Pred1 V1, C1) & (icmp Pred2 V2, C2)
+/// or   (icmp Pred1 V1, C1) | (icmp Pred2 V2, C2)
+/// into a single comparison using range-based reasoning.
+/// NOTE: This is also used for logical and/or, must be poison-safe!
+Value *InstCombinerImpl::foldAndOrOfICmpsUsingRanges(ICmpInst *ICmp1,
+                                                     ICmpInst *ICmp2,
+                                                     bool IsAnd) {
+  ICmpInst::Predicate Pred1, Pred2;
+  Value *V1, *V2;
+  const APInt *C1, *C2;
+  if (!match(ICmp1, m_ICmp(Pred1, m_Value(V1), m_APInt(C1))) ||
+      !match(ICmp2, m_ICmp(Pred2, m_Value(V2), m_APInt(C2))))
+    return nullptr;
+
+  // Look through add of a constant offset on V1, V2, or both operands. This
+  // allows us to interpret the V + C' < C'' range idiom into a proper range.
+  const APInt *Offset1 = nullptr, *Offset2 = nullptr;
+  if (V1 != V2) {
+    Value *X;
+    if (match(V1, m_Add(m_Value(X), m_APInt(Offset1))))
+      V1 = X;
+    if (match(V2, m_Add(m_Value(X), m_APInt(Offset2))))
+      V2 = X;
+  }
+
+  if (V1 != V2)
+    return nullptr;
+
+  // The ranges are always combined as a union: for 'and', the inverse
+  // predicates are used here and the combined range is inverted again below.
+  ConstantRange CR1 = ConstantRange::makeExactICmpRegion(
+      IsAnd ? ICmpInst::getInversePredicate(Pred1) : Pred1, *C1);
+  if (Offset1)
+    CR1 = CR1.subtract(*Offset1);
+
+  ConstantRange CR2 = ConstantRange::makeExactICmpRegion(
+      IsAnd ? ICmpInst::getInversePredicate(Pred2) : Pred2, *C2);
+  if (Offset2)
+    CR2 = CR2.subtract(*Offset2);
+
+  Type *Ty = V1->getType();
+  Value *NewV = V1;
+  Optional<ConstantRange> CR = CR1.exactUnionWith(CR2);
+  if (!CR) {
+    if (!(ICmp1->hasOneUse() && ICmp2->hasOneUse()) || CR1.isWrappedSet() ||
+        CR2.isWrappedSet())
+      return nullptr;
+
+    // Check whether we have equal-size ranges that only differ by one bit.
+    // In that case we can apply a mask to map one range onto the other.
+    APInt LowerDiff = CR1.getLower() ^ CR2.getLower();
+    APInt UpperDiff = (CR1.getUpper() - 1) ^ (CR2.getUpper() - 1);
+    APInt CR1Size = CR1.getUpper() - CR1.getLower();
+    if (!LowerDiff.isPowerOf2() || LowerDiff != UpperDiff ||
+        CR1Size != CR2.getUpper() - CR2.getLower())
+      return nullptr;
+
+    CR = CR1.getLower().ult(CR2.getLower()) ? CR1 : CR2;
+    NewV = Builder.CreateAnd(NewV, ConstantInt::get(Ty, ~LowerDiff));
+  }
+
+  if (IsAnd)
+    CR = CR->inverse();
+
+  CmpInst::Predicate NewPred;
+  APInt NewC, Offset;
+  CR->getEquivalentICmp(NewPred, NewC, Offset);
+
+  if (Offset != 0)
+    NewV = Builder.CreateAdd(NewV, ConstantInt::get(Ty, Offset));
+  return Builder.CreateICmp(NewPred, NewV, ConstantInt::get(Ty, NewC));
+}
+
+/// Try to fold a logic op (and/or, selected by \p IsAnd) of two fcmps into a
+/// single fcmp. Handles two compares of the same operand pair, and a pair of
+/// ord/uno checks against +0.0 (the latter only when \p IsLogicalSelect is
+/// false). Returns nullptr if neither pattern applies.
+Value *InstCombinerImpl::foldLogicOfFCmps(FCmpInst *LHS, FCmpInst *RHS,
+                                          bool IsAnd, bool IsLogicalSelect) {
+  Value *LHS0 = LHS->getOperand(0), *LHS1 = LHS->getOperand(1);
+  Value *RHS0 = RHS->getOperand(0), *RHS1 = RHS->getOperand(1);
+  FCmpInst::Predicate PredL = LHS->getPredicate(), PredR = RHS->getPredicate();
+
+  if (LHS0 == RHS1 && RHS0 == LHS1) {
+    // Swap RHS operands to match LHS.
+    PredR = FCmpInst::getSwappedPredicate(PredR);
+    std::swap(RHS0, RHS1);
+  }
+
+  // Simplify (fcmp cc0 x, y) & (fcmp cc1 x, y).
+  // Suppose the relation between x and y is R, where R is one of
+  // U(1000), L(0100), G(0010) or E(0001), and CC0 and CC1 are the bitmasks for
+  // testing the desired relations.
+  //
+  // Since (R & CC0) and (R & CC1) are either R or 0, we actually have this:
+  //    bool(R & CC0) && bool(R & CC1)
+  //  = bool((R & CC0) & (R & CC1))
+  //  = bool(R & (CC0 & CC1)) <= by re-association, commutation, and idempotency
+  //
+  // Since (R & CC0) and (R & CC1) are either R or 0, we actually have this:
+  //    bool(R & CC0) || bool(R & CC1)
+  //  = bool((R & CC0) | (R & CC1))
+  //  = bool(R & (CC0 | CC1)) <= by reversed distribution (contribution? ;)
+  if (LHS0 == RHS0 && LHS1 == RHS1) {
+    unsigned FCmpCodeL = getFCmpCode(PredL);
+    unsigned FCmpCodeR = getFCmpCode(PredR);
+    unsigned NewPred = IsAnd ? FCmpCodeL & FCmpCodeR : FCmpCodeL | FCmpCodeR;
+
+    // Intersect the fast math flags.
+    // TODO: We can union the fast math flags unless this is a logical select.
+    IRBuilder<>::FastMathFlagGuard FMFG(Builder);
+    FastMathFlags FMF = LHS->getFastMathFlags();
+    FMF &= RHS->getFastMathFlags();
+    Builder.setFastMathFlags(FMF);
+
+    return getFCmpValue(NewPred, LHS0, LHS1, Builder);
+  }
+
+  // This transform is not valid for a logical select.
+  if (!IsLogicalSelect &&
+      ((PredL == FCmpInst::FCMP_ORD && PredR == FCmpInst::FCMP_ORD && IsAnd) ||
+       (PredL == FCmpInst::FCMP_UNO && PredR == FCmpInst::FCMP_UNO &&
+        !IsAnd))) {
+    if (LHS0->getType() != RHS0->getType())
+      return nullptr;
+
+    // FCmp canonicalization ensures that (fcmp ord/uno X, X) and
+    // (fcmp ord/uno X, C) will be transformed to (fcmp X, +0.0).
+    if (match(LHS1, m_PosZeroFP()) && match(RHS1, m_PosZeroFP()))
+      // Ignore the constants because they are obviously not NANs:
+      // (fcmp ord x, 0.0) & (fcmp ord y, 0.0)  -> (fcmp ord x, y)
+      // (fcmp uno x, 0.0) | (fcmp uno y, 0.0)  -> (fcmp uno x, y)
+      return Builder.CreateFCmp(PredL, LHS0, RHS0);
+  }
+
+  return nullptr;
+}
+
+/// This is a limited reassociation for a special case (see above) where we are
+/// checking if two values are either both NAN (unordered) or not-NAN (ordered).
+/// This could be handled more generally in '-reassociation', but it seems like
+/// an unlikely pattern for a large number of logic ops and fcmps.
+static Instruction *reassociateFCmps(BinaryOperator &BO,
+                                     InstCombiner::BuilderTy &Builder) {
+  Instruction::BinaryOps Opcode = BO.getOpcode();
+  assert((Opcode == Instruction::And || Opcode == Instruction::Or) &&
+         "Expecting and/or op for fcmp transform");
+
+  // There are 4 commuted variants of the pattern. Canonicalize operands of this
+  // logic op so an fcmp is operand 0 and a matching logic op is operand 1.
+  Value *Op0 = BO.getOperand(0), *Op1 = BO.getOperand(1), *X;
+  FCmpInst::Predicate Pred;
+  if (match(Op1, m_FCmp(Pred, m_Value(), m_AnyZeroFP())))
+    std::swap(Op0, Op1);
+
+  // Match inner binop and the predicate for combining 2 NAN checks into 1.
+  Value *BO10, *BO11;
+  FCmpInst::Predicate NanPred = Opcode == Instruction::And ? FCmpInst::FCMP_ORD
+                                                           : FCmpInst::FCMP_UNO;
+  if (!match(Op0, m_FCmp(Pred, m_Value(X), m_AnyZeroFP())) || Pred != NanPred ||
+      !match(Op1, m_BinOp(Opcode, m_Value(BO10), m_Value(BO11))))
+    return nullptr;
+
+  // The inner logic op must have a matching fcmp operand.
+  // Try BO10 first; if it does not match, commute the inner operands and
+  // check again.
+  Value *Y;
+  if (!match(BO10, m_FCmp(Pred, m_Value(Y), m_AnyZeroFP())) ||
+      Pred != NanPred || X->getType() != Y->getType())
+    std::swap(BO10, BO11);
+
+  if (!match(BO10, m_FCmp(Pred, m_Value(Y), m_AnyZeroFP())) ||
+      Pred != NanPred || X->getType() != Y->getType())
+    return nullptr;
+
+  // and (fcmp ord X, 0), (and (fcmp ord Y, 0), Z) --> and (fcmp ord X, Y), Z
+  // or  (fcmp uno X, 0), (or  (fcmp uno Y, 0), Z) --> or  (fcmp uno X, Y), Z
+  Value *NewFCmp = Builder.CreateFCmp(Pred, X, Y);
+  if (auto *NewFCmpInst = dyn_cast<FCmpInst>(NewFCmp)) {
+    // Intersect FMF from the 2 source fcmps.
+    NewFCmpInst->copyIRFlags(Op0);
+    NewFCmpInst->andIRFlags(BO10);
+  }
+  return BinaryOperator::Create(Opcode, NewFCmp, BO11);
+}
+
+/// Match variations of De Morgan's Laws:
+/// (~A & ~B) == (~(A | B))
+/// (~A | ~B) == (~(A & B))
+static Instruction *matchDeMorgansLaws(BinaryOperator &I,
+                                       InstCombiner::BuilderTy &Builder) {
+  const Instruction::BinaryOps Opcode = I.getOpcode();
+  assert((Opcode == Instruction::And || Opcode == Instruction::Or) &&
+         "Trying to match De Morgan's Laws with something other than and/or");
+
+  // Flip the logic operation.
+  const Instruction::BinaryOps FlippedOpcode =
+      (Opcode == Instruction::And) ? Instruction::Or : Instruction::And;
+
+  Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
+  Value *A, *B;
+  if (match(Op0, m_OneUse(m_Not(m_Value(A)))) &&
+      match(Op1, m_OneUse(m_Not(m_Value(B)))) &&
+      !InstCombiner::isFreeToInvert(A, A->hasOneUse()) &&
+      !InstCombiner::isFreeToInvert(B, B->hasOneUse())) {
+    Value *AndOr =
+        Builder.CreateBinOp(FlippedOpcode, A, B, I.getName() + ".demorgan");
+    return BinaryOperator::CreateNot(AndOr);
+  }
+
+  // The 'not' ops may require reassociation.
+ // (A & ~B) & ~C --> A & ~(B | C) + // (~B & A) & ~C --> A & ~(B | C) + // (A | ~B) | ~C --> A | ~(B & C) + // (~B | A) | ~C --> A | ~(B & C) + Value *C; + if (match(Op0, m_OneUse(m_c_BinOp(Opcode, m_Value(A), m_Not(m_Value(B))))) && + match(Op1, m_Not(m_Value(C)))) { + Value *FlippedBO = Builder.CreateBinOp(FlippedOpcode, B, C); + return BinaryOperator::Create(Opcode, A, Builder.CreateNot(FlippedBO)); + } + + return nullptr; +} + +bool InstCombinerImpl::shouldOptimizeCast(CastInst *CI) { + Value *CastSrc = CI->getOperand(0); + + // Noop casts and casts of constants should be eliminated trivially. + if (CI->getSrcTy() == CI->getDestTy() || isa<Constant>(CastSrc)) + return false; + + // If this cast is paired with another cast that can be eliminated, we prefer + // to have it eliminated. + if (const auto *PrecedingCI = dyn_cast<CastInst>(CastSrc)) + if (isEliminableCastPair(PrecedingCI, CI)) + return false; + + return true; +} + +/// Fold {and,or,xor} (cast X), C. +static Instruction *foldLogicCastConstant(BinaryOperator &Logic, CastInst *Cast, + InstCombiner::BuilderTy &Builder) { + Constant *C = dyn_cast<Constant>(Logic.getOperand(1)); + if (!C) + return nullptr; + + auto LogicOpc = Logic.getOpcode(); + Type *DestTy = Logic.getType(); + Type *SrcTy = Cast->getSrcTy(); + + // Move the logic operation ahead of a zext or sext if the constant is + // unchanged in the smaller source type. Performing the logic in a smaller + // type may provide more information to later folds, and the smaller logic + // instruction may be cheaper (particularly in the case of vectors). 
+ Value *X; + if (match(Cast, m_OneUse(m_ZExt(m_Value(X))))) { + Constant *TruncC = ConstantExpr::getTrunc(C, SrcTy); + Constant *ZextTruncC = ConstantExpr::getZExt(TruncC, DestTy); + if (ZextTruncC == C) { + // LogicOpc (zext X), C --> zext (LogicOpc X, C) + Value *NewOp = Builder.CreateBinOp(LogicOpc, X, TruncC); + return new ZExtInst(NewOp, DestTy); + } + } + + if (match(Cast, m_OneUse(m_SExt(m_Value(X))))) { + Constant *TruncC = ConstantExpr::getTrunc(C, SrcTy); + Constant *SextTruncC = ConstantExpr::getSExt(TruncC, DestTy); + if (SextTruncC == C) { + // LogicOpc (sext X), C --> sext (LogicOpc X, C) + Value *NewOp = Builder.CreateBinOp(LogicOpc, X, TruncC); + return new SExtInst(NewOp, DestTy); + } + } + + return nullptr; +} + +/// Fold {and,or,xor} (cast X), Y. +Instruction *InstCombinerImpl::foldCastedBitwiseLogic(BinaryOperator &I) { + auto LogicOpc = I.getOpcode(); + assert(I.isBitwiseLogicOp() && "Unexpected opcode for bitwise logic folding"); + + Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); + CastInst *Cast0 = dyn_cast<CastInst>(Op0); + if (!Cast0) + return nullptr; + + // This must be a cast from an integer or integer vector source type to allow + // transformation of the logic operation to the source type. + Type *DestTy = I.getType(); + Type *SrcTy = Cast0->getSrcTy(); + if (!SrcTy->isIntOrIntVectorTy()) + return nullptr; + + if (Instruction *Ret = foldLogicCastConstant(I, Cast0, Builder)) + return Ret; + + CastInst *Cast1 = dyn_cast<CastInst>(Op1); + if (!Cast1) + return nullptr; + + // Both operands of the logic operation are casts. The casts must be of the + // same type for reduction. 
+ auto CastOpcode = Cast0->getOpcode(); + if (CastOpcode != Cast1->getOpcode() || SrcTy != Cast1->getSrcTy()) + return nullptr; + + Value *Cast0Src = Cast0->getOperand(0); + Value *Cast1Src = Cast1->getOperand(0); + + // fold logic(cast(A), cast(B)) -> cast(logic(A, B)) + if ((Cast0->hasOneUse() || Cast1->hasOneUse()) && + shouldOptimizeCast(Cast0) && shouldOptimizeCast(Cast1)) { + Value *NewOp = Builder.CreateBinOp(LogicOpc, Cast0Src, Cast1Src, + I.getName()); + return CastInst::Create(CastOpcode, NewOp, DestTy); + } + + // For now, only 'and'/'or' have optimizations after this. + if (LogicOpc == Instruction::Xor) + return nullptr; + + // If this is logic(cast(icmp), cast(icmp)), try to fold this even if the + // cast is otherwise not optimizable. This happens for vector sexts. + ICmpInst *ICmp0 = dyn_cast<ICmpInst>(Cast0Src); + ICmpInst *ICmp1 = dyn_cast<ICmpInst>(Cast1Src); + if (ICmp0 && ICmp1) { + if (Value *Res = + foldAndOrOfICmps(ICmp0, ICmp1, I, LogicOpc == Instruction::And)) + return CastInst::Create(CastOpcode, Res, DestTy); + return nullptr; + } + + // If this is logic(cast(fcmp), cast(fcmp)), try to fold this even if the + // cast is otherwise not optimizable. This happens for vector sexts. + FCmpInst *FCmp0 = dyn_cast<FCmpInst>(Cast0Src); + FCmpInst *FCmp1 = dyn_cast<FCmpInst>(Cast1Src); + if (FCmp0 && FCmp1) + if (Value *R = foldLogicOfFCmps(FCmp0, FCmp1, LogicOpc == Instruction::And)) + return CastInst::Create(CastOpcode, R, DestTy); + + return nullptr; +} + +static Instruction *foldAndToXor(BinaryOperator &I, + InstCombiner::BuilderTy &Builder) { + assert(I.getOpcode() == Instruction::And); + Value *Op0 = I.getOperand(0); + Value *Op1 = I.getOperand(1); + Value *A, *B; + + // Operand complexity canonicalization guarantees that the 'or' is Op0. 
+ // (A | B) & ~(A & B) --> A ^ B + // (A | B) & ~(B & A) --> A ^ B + if (match(&I, m_BinOp(m_Or(m_Value(A), m_Value(B)), + m_Not(m_c_And(m_Deferred(A), m_Deferred(B)))))) + return BinaryOperator::CreateXor(A, B); + + // (A | ~B) & (~A | B) --> ~(A ^ B) + // (A | ~B) & (B | ~A) --> ~(A ^ B) + // (~B | A) & (~A | B) --> ~(A ^ B) + // (~B | A) & (B | ~A) --> ~(A ^ B) + if (Op0->hasOneUse() || Op1->hasOneUse()) + if (match(&I, m_BinOp(m_c_Or(m_Value(A), m_Not(m_Value(B))), + m_c_Or(m_Not(m_Deferred(A)), m_Deferred(B))))) + return BinaryOperator::CreateNot(Builder.CreateXor(A, B)); + + return nullptr; +} + +static Instruction *foldOrToXor(BinaryOperator &I, + InstCombiner::BuilderTy &Builder) { + assert(I.getOpcode() == Instruction::Or); + Value *Op0 = I.getOperand(0); + Value *Op1 = I.getOperand(1); + Value *A, *B; + + // Operand complexity canonicalization guarantees that the 'and' is Op0. + // (A & B) | ~(A | B) --> ~(A ^ B) + // (A & B) | ~(B | A) --> ~(A ^ B) + if (Op0->hasOneUse() || Op1->hasOneUse()) + if (match(Op0, m_And(m_Value(A), m_Value(B))) && + match(Op1, m_Not(m_c_Or(m_Specific(A), m_Specific(B))))) + return BinaryOperator::CreateNot(Builder.CreateXor(A, B)); + + // Operand complexity canonicalization guarantees that the 'xor' is Op0. + // (A ^ B) | ~(A | B) --> ~(A & B) + // (A ^ B) | ~(B | A) --> ~(A & B) + if (Op0->hasOneUse() || Op1->hasOneUse()) + if (match(Op0, m_Xor(m_Value(A), m_Value(B))) && + match(Op1, m_Not(m_c_Or(m_Specific(A), m_Specific(B))))) + return BinaryOperator::CreateNot(Builder.CreateAnd(A, B)); + + // (A & ~B) | (~A & B) --> A ^ B + // (A & ~B) | (B & ~A) --> A ^ B + // (~B & A) | (~A & B) --> A ^ B + // (~B & A) | (B & ~A) --> A ^ B + if (match(Op0, m_c_And(m_Value(A), m_Not(m_Value(B)))) && + match(Op1, m_c_And(m_Not(m_Specific(A)), m_Specific(B)))) + return BinaryOperator::CreateXor(A, B); + + return nullptr; +} + +/// Return true if a constant shift amount is always less than the specified +/// bit-width. 
If not, the shift could create poison in the narrower type. +static bool canNarrowShiftAmt(Constant *C, unsigned BitWidth) { + APInt Threshold(C->getType()->getScalarSizeInBits(), BitWidth); + return match(C, m_SpecificInt_ICMP(ICmpInst::ICMP_ULT, Threshold)); +} + +/// Try to use narrower ops (sink zext ops) for an 'and' with binop operand and +/// a common zext operand: and (binop (zext X), C), (zext X). +Instruction *InstCombinerImpl::narrowMaskedBinOp(BinaryOperator &And) { + // This transform could also apply to {or, and, xor}, but there are better + // folds for those cases, so we don't expect those patterns here. AShr is not + // handled because it should always be transformed to LShr in this sequence. + // The subtract transform is different because it has a constant on the left. + // Add/mul commute the constant to RHS; sub with constant RHS becomes add. + Value *Op0 = And.getOperand(0), *Op1 = And.getOperand(1); + Constant *C; + if (!match(Op0, m_OneUse(m_Add(m_Specific(Op1), m_Constant(C)))) && + !match(Op0, m_OneUse(m_Mul(m_Specific(Op1), m_Constant(C)))) && + !match(Op0, m_OneUse(m_LShr(m_Specific(Op1), m_Constant(C)))) && + !match(Op0, m_OneUse(m_Shl(m_Specific(Op1), m_Constant(C)))) && + !match(Op0, m_OneUse(m_Sub(m_Constant(C), m_Specific(Op1))))) + return nullptr; + + Value *X; + if (!match(Op1, m_ZExt(m_Value(X))) || Op1->hasNUsesOrMore(3)) + return nullptr; + + Type *Ty = And.getType(); + if (!isa<VectorType>(Ty) && !shouldChangeType(Ty, X->getType())) + return nullptr; + + // If we're narrowing a shift, the shift amount must be safe (less than the + // width) in the narrower type. If the shift amount is greater, instsimplify + // usually handles that case, but we can't guarantee/assert it. 
+ Instruction::BinaryOps Opc = cast<BinaryOperator>(Op0)->getOpcode(); + if (Opc == Instruction::LShr || Opc == Instruction::Shl) + if (!canNarrowShiftAmt(C, X->getType()->getScalarSizeInBits())) + return nullptr; + + // and (sub C, (zext X)), (zext X) --> zext (and (sub C', X), X) + // and (binop (zext X), C), (zext X) --> zext (and (binop X, C'), X) + Value *NewC = ConstantExpr::getTrunc(C, X->getType()); + Value *NewBO = Opc == Instruction::Sub ? Builder.CreateBinOp(Opc, NewC, X) + : Builder.CreateBinOp(Opc, X, NewC); + return new ZExtInst(Builder.CreateAnd(NewBO, X), Ty); +} + +/// Try folding relatively complex patterns for both And and Or operations +/// with all And and Or swapped. +static Instruction *foldComplexAndOrPatterns(BinaryOperator &I, + InstCombiner::BuilderTy &Builder) { + const Instruction::BinaryOps Opcode = I.getOpcode(); + assert(Opcode == Instruction::And || Opcode == Instruction::Or); + + // Flip the logic operation. + const Instruction::BinaryOps FlippedOpcode = + (Opcode == Instruction::And) ? Instruction::Or : Instruction::And; + + Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); + Value *A, *B, *C, *X, *Y, *Dummy; + + // Match following expressions: + // (~(A | B) & C) + // (~(A & B) | C) + // Captures X = ~(A | B) or ~(A & B) + const auto matchNotOrAnd = + [Opcode, FlippedOpcode](Value *Op, auto m_A, auto m_B, auto m_C, + Value *&X, bool CountUses = false) -> bool { + if (CountUses && !Op->hasOneUse()) + return false; + + if (match(Op, m_c_BinOp(FlippedOpcode, + m_CombineAnd(m_Value(X), + m_Not(m_c_BinOp(Opcode, m_A, m_B))), + m_C))) + return !CountUses || X->hasOneUse(); + + return false; + }; + + // (~(A | B) & C) | ... --> ... + // (~(A & B) | C) & ... --> ... + // TODO: One use checks are conservative. We just need to check that a total + // number of multiple used values does not exceed reduction + // in operations. 
+ if (matchNotOrAnd(Op0, m_Value(A), m_Value(B), m_Value(C), X)) { + // (~(A | B) & C) | (~(A | C) & B) --> (B ^ C) & ~A + // (~(A & B) | C) & (~(A & C) | B) --> ~((B ^ C) & A) + if (matchNotOrAnd(Op1, m_Specific(A), m_Specific(C), m_Specific(B), Dummy, + true)) { + Value *Xor = Builder.CreateXor(B, C); + return (Opcode == Instruction::Or) + ? BinaryOperator::CreateAnd(Xor, Builder.CreateNot(A)) + : BinaryOperator::CreateNot(Builder.CreateAnd(Xor, A)); + } + + // (~(A | B) & C) | (~(B | C) & A) --> (A ^ C) & ~B + // (~(A & B) | C) & (~(B & C) | A) --> ~((A ^ C) & B) + if (matchNotOrAnd(Op1, m_Specific(B), m_Specific(C), m_Specific(A), Dummy, + true)) { + Value *Xor = Builder.CreateXor(A, C); + return (Opcode == Instruction::Or) + ? BinaryOperator::CreateAnd(Xor, Builder.CreateNot(B)) + : BinaryOperator::CreateNot(Builder.CreateAnd(Xor, B)); + } + + // (~(A | B) & C) | ~(A | C) --> ~((B & C) | A) + // (~(A & B) | C) & ~(A & C) --> ~((B | C) & A) + if (match(Op1, m_OneUse(m_Not(m_OneUse( + m_c_BinOp(Opcode, m_Specific(A), m_Specific(C))))))) + return BinaryOperator::CreateNot(Builder.CreateBinOp( + Opcode, Builder.CreateBinOp(FlippedOpcode, B, C), A)); + + // (~(A | B) & C) | ~(B | C) --> ~((A & C) | B) + // (~(A & B) | C) & ~(B & C) --> ~((A | C) & B) + if (match(Op1, m_OneUse(m_Not(m_OneUse( + m_c_BinOp(Opcode, m_Specific(B), m_Specific(C))))))) + return BinaryOperator::CreateNot(Builder.CreateBinOp( + Opcode, Builder.CreateBinOp(FlippedOpcode, A, C), B)); + + // (~(A | B) & C) | ~(C | (A ^ B)) --> ~((A | B) & (C | (A ^ B))) + // Note, the pattern with swapped and/or is not handled because the + // result is more undefined than a source: + // (~(A & B) | C) & ~(C & (A ^ B)) --> (A ^ B ^ C) | ~(A | C) is invalid. 
+ if (Opcode == Instruction::Or && Op0->hasOneUse() && + match(Op1, m_OneUse(m_Not(m_CombineAnd( + m_Value(Y), + m_c_BinOp(Opcode, m_Specific(C), + m_c_Xor(m_Specific(A), m_Specific(B)))))))) { + // X = ~(A | B) + // Y = (C | (A ^ B) + Value *Or = cast<BinaryOperator>(X)->getOperand(0); + return BinaryOperator::CreateNot(Builder.CreateAnd(Or, Y)); + } + } + + // (~A & B & C) | ... --> ... + // (~A | B | C) | ... --> ... + // TODO: One use checks are conservative. We just need to check that a total + // number of multiple used values does not exceed reduction + // in operations. + if (match(Op0, + m_OneUse(m_c_BinOp(FlippedOpcode, + m_BinOp(FlippedOpcode, m_Value(B), m_Value(C)), + m_CombineAnd(m_Value(X), m_Not(m_Value(A)))))) || + match(Op0, m_OneUse(m_c_BinOp( + FlippedOpcode, + m_c_BinOp(FlippedOpcode, m_Value(C), + m_CombineAnd(m_Value(X), m_Not(m_Value(A)))), + m_Value(B))))) { + // X = ~A + // (~A & B & C) | ~(A | B | C) --> ~(A | (B ^ C)) + // (~A | B | C) & ~(A & B & C) --> (~A | (B ^ C)) + if (match(Op1, m_OneUse(m_Not(m_c_BinOp( + Opcode, m_c_BinOp(Opcode, m_Specific(A), m_Specific(B)), + m_Specific(C))))) || + match(Op1, m_OneUse(m_Not(m_c_BinOp( + Opcode, m_c_BinOp(Opcode, m_Specific(B), m_Specific(C)), + m_Specific(A))))) || + match(Op1, m_OneUse(m_Not(m_c_BinOp( + Opcode, m_c_BinOp(Opcode, m_Specific(A), m_Specific(C)), + m_Specific(B)))))) { + Value *Xor = Builder.CreateXor(B, C); + return (Opcode == Instruction::Or) + ? 
BinaryOperator::CreateNot(Builder.CreateOr(Xor, A)) + : BinaryOperator::CreateOr(Xor, X); + } + + // (~A & B & C) | ~(A | B) --> (C | ~B) & ~A + // (~A | B | C) & ~(A & B) --> (C & ~B) | ~A + if (match(Op1, m_OneUse(m_Not(m_OneUse( + m_c_BinOp(Opcode, m_Specific(A), m_Specific(B))))))) + return BinaryOperator::Create( + FlippedOpcode, Builder.CreateBinOp(Opcode, C, Builder.CreateNot(B)), + X); + + // (~A & B & C) | ~(A | C) --> (B | ~C) & ~A + // (~A | B | C) & ~(A & C) --> (B & ~C) | ~A + if (match(Op1, m_OneUse(m_Not(m_OneUse( + m_c_BinOp(Opcode, m_Specific(A), m_Specific(C))))))) + return BinaryOperator::Create( + FlippedOpcode, Builder.CreateBinOp(Opcode, B, Builder.CreateNot(C)), + X); + } + + return nullptr; +} + +// FIXME: We use commutative matchers (m_c_*) for some, but not all, matches +// here. We should standardize that construct where it is needed or choose some +// other way to ensure that commutated variants of patterns are not missed. +Instruction *InstCombinerImpl::visitAnd(BinaryOperator &I) { + Type *Ty = I.getType(); + + if (Value *V = simplifyAndInst(I.getOperand(0), I.getOperand(1), + SQ.getWithInstruction(&I))) + return replaceInstUsesWith(I, V); + + if (SimplifyAssociativeOrCommutative(I)) + return &I; + + if (Instruction *X = foldVectorBinop(I)) + return X; + + if (Instruction *Phi = foldBinopWithPhiOperands(I)) + return Phi; + + // See if we can simplify any instructions used by the instruction whose sole + // purpose is to compute bits we don't care about. + if (SimplifyDemandedInstructionBits(I)) + return &I; + + // Do this before using distributive laws to catch simple and/or/not patterns. 
+ if (Instruction *Xor = foldAndToXor(I, Builder)) + return Xor; + + if (Instruction *X = foldComplexAndOrPatterns(I, Builder)) + return X; + + // (A|B)&(A|C) -> A|(B&C) etc + if (Value *V = SimplifyUsingDistributiveLaws(I)) + return replaceInstUsesWith(I, V); + + if (Value *V = SimplifyBSwap(I, Builder)) + return replaceInstUsesWith(I, V); + + Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); + + Value *X, *Y; + if (match(Op0, m_OneUse(m_LogicalShift(m_One(), m_Value(X)))) && + match(Op1, m_One())) { + // (1 << X) & 1 --> zext(X == 0) + // (1 >> X) & 1 --> zext(X == 0) + Value *IsZero = Builder.CreateICmpEQ(X, ConstantInt::get(Ty, 0)); + return new ZExtInst(IsZero, Ty); + } + + const APInt *C; + if (match(Op1, m_APInt(C))) { + const APInt *XorC; + if (match(Op0, m_OneUse(m_Xor(m_Value(X), m_APInt(XorC))))) { + // (X ^ C1) & C2 --> (X & C2) ^ (C1&C2) + Constant *NewC = ConstantInt::get(Ty, *C & *XorC); + Value *And = Builder.CreateAnd(X, Op1); + And->takeName(Op0); + return BinaryOperator::CreateXor(And, NewC); + } + + const APInt *OrC; + if (match(Op0, m_OneUse(m_Or(m_Value(X), m_APInt(OrC))))) { + // (X | C1) & C2 --> (X & C2^(C1&C2)) | (C1&C2) + // NOTE: This reduces the number of bits set in the & mask, which + // can expose opportunities for store narrowing for scalars. + // NOTE: SimplifyDemandedBits should have already removed bits from C1 + // that aren't set in C2. Meaning we can replace (C1&C2) with C1 in + // above, but this feels safer. 
+ APInt Together = *C & *OrC; + Value *And = Builder.CreateAnd(X, ConstantInt::get(Ty, Together ^ *C)); + And->takeName(Op0); + return BinaryOperator::CreateOr(And, ConstantInt::get(Ty, Together)); + } + + unsigned Width = Ty->getScalarSizeInBits(); + const APInt *ShiftC; + if (match(Op0, m_OneUse(m_SExt(m_AShr(m_Value(X), m_APInt(ShiftC)))))) { + if (*C == APInt::getLowBitsSet(Width, Width - ShiftC->getZExtValue())) { + // We are clearing high bits that were potentially set by sext+ashr: + // and (sext (ashr X, ShiftC)), C --> lshr (sext X), ShiftC + Value *Sext = Builder.CreateSExt(X, Ty); + Constant *ShAmtC = ConstantInt::get(Ty, ShiftC->zext(Width)); + return BinaryOperator::CreateLShr(Sext, ShAmtC); + } + } + + // If this 'and' clears the sign-bits added by ashr, replace with lshr: + // and (ashr X, ShiftC), C --> lshr X, ShiftC + if (match(Op0, m_AShr(m_Value(X), m_APInt(ShiftC))) && ShiftC->ult(Width) && + C->isMask(Width - ShiftC->getZExtValue())) + return BinaryOperator::CreateLShr(X, ConstantInt::get(Ty, *ShiftC)); + + const APInt *AddC; + if (match(Op0, m_Add(m_Value(X), m_APInt(AddC)))) { + // If we add zeros to every bit below a mask, the add has no effect: + // (X + AddC) & LowMaskC --> X & LowMaskC + unsigned Ctlz = C->countLeadingZeros(); + APInt LowMask(APInt::getLowBitsSet(Width, Width - Ctlz)); + if ((*AddC & LowMask).isZero()) + return BinaryOperator::CreateAnd(X, Op1); + + // If we are masking the result of the add down to exactly one bit and + // the constant we are adding has no bits set below that bit, then the + // add is flipping a single bit. 
Example: + // (X + 4) & 4 --> (X & 4) ^ 4 + if (Op0->hasOneUse() && C->isPowerOf2() && (*AddC & (*C - 1)) == 0) { + assert((*C & *AddC) != 0 && "Expected common bit"); + Value *NewAnd = Builder.CreateAnd(X, Op1); + return BinaryOperator::CreateXor(NewAnd, Op1); + } + } + + // ((C1 OP zext(X)) & C2) -> zext((C1 OP X) & C2) if C2 fits in the + // bitwidth of X and OP behaves well when given trunc(C1) and X. + auto isNarrowableBinOpcode = [](BinaryOperator *B) { + switch (B->getOpcode()) { + case Instruction::Xor: + case Instruction::Or: + case Instruction::Mul: + case Instruction::Add: + case Instruction::Sub: + return true; + default: + return false; + } + }; + BinaryOperator *BO; + if (match(Op0, m_OneUse(m_BinOp(BO))) && isNarrowableBinOpcode(BO)) { + Instruction::BinaryOps BOpcode = BO->getOpcode(); + Value *X; + const APInt *C1; + // TODO: The one-use restrictions could be relaxed a little if the AND + // is going to be removed. + // Try to narrow the 'and' and a binop with constant operand: + // and (bo (zext X), C1), C --> zext (and (bo X, TruncC1), TruncC) + if (match(BO, m_c_BinOp(m_OneUse(m_ZExt(m_Value(X))), m_APInt(C1))) && + C->isIntN(X->getType()->getScalarSizeInBits())) { + unsigned XWidth = X->getType()->getScalarSizeInBits(); + Constant *TruncC1 = ConstantInt::get(X->getType(), C1->trunc(XWidth)); + Value *BinOp = isa<ZExtInst>(BO->getOperand(0)) + ? 
Builder.CreateBinOp(BOpcode, X, TruncC1) + : Builder.CreateBinOp(BOpcode, TruncC1, X); + Constant *TruncC = ConstantInt::get(X->getType(), C->trunc(XWidth)); + Value *And = Builder.CreateAnd(BinOp, TruncC); + return new ZExtInst(And, Ty); + } + + // Similar to above: if the mask matches the zext input width, then the + // 'and' can be eliminated, so we can truncate the other variable op: + // and (bo (zext X), Y), C --> zext (bo X, (trunc Y)) + if (isa<Instruction>(BO->getOperand(0)) && + match(BO->getOperand(0), m_OneUse(m_ZExt(m_Value(X)))) && + C->isMask(X->getType()->getScalarSizeInBits())) { + Y = BO->getOperand(1); + Value *TrY = Builder.CreateTrunc(Y, X->getType(), Y->getName() + ".tr"); + Value *NewBO = + Builder.CreateBinOp(BOpcode, X, TrY, BO->getName() + ".narrow"); + return new ZExtInst(NewBO, Ty); + } + // and (bo Y, (zext X)), C --> zext (bo (trunc Y), X) + if (isa<Instruction>(BO->getOperand(1)) && + match(BO->getOperand(1), m_OneUse(m_ZExt(m_Value(X)))) && + C->isMask(X->getType()->getScalarSizeInBits())) { + Y = BO->getOperand(0); + Value *TrY = Builder.CreateTrunc(Y, X->getType(), Y->getName() + ".tr"); + Value *NewBO = + Builder.CreateBinOp(BOpcode, TrY, X, BO->getName() + ".narrow"); + return new ZExtInst(NewBO, Ty); + } + } + + // This is intentionally placed after the narrowing transforms for + // efficiency (transform directly to the narrow logic op if possible). + // If the mask is only needed on one incoming arm, push the 'and' op up. + if (match(Op0, m_OneUse(m_Xor(m_Value(X), m_Value(Y)))) || + match(Op0, m_OneUse(m_Or(m_Value(X), m_Value(Y))))) { + APInt NotAndMask(~(*C)); + BinaryOperator::BinaryOps BinOp = cast<BinaryOperator>(Op0)->getOpcode(); + if (MaskedValueIsZero(X, NotAndMask, 0, &I)) { + // Not masking anything out for the LHS, move mask to RHS. 
+ // and ({x}or X, Y), C --> {x}or X, (and Y, C) + Value *NewRHS = Builder.CreateAnd(Y, Op1, Y->getName() + ".masked"); + return BinaryOperator::Create(BinOp, X, NewRHS); + } + if (!isa<Constant>(Y) && MaskedValueIsZero(Y, NotAndMask, 0, &I)) { + // Not masking anything out for the RHS, move mask to LHS. + // and ({x}or X, Y), C --> {x}or (and X, C), Y + Value *NewLHS = Builder.CreateAnd(X, Op1, X->getName() + ".masked"); + return BinaryOperator::Create(BinOp, NewLHS, Y); + } + } + + // When the mask is a power-of-2 constant and op0 is a shifted-power-of-2 + // constant, test if the shift amount equals the offset bit index: + // (ShiftC << X) & C --> X == (log2(C) - log2(ShiftC)) ? C : 0 + // (ShiftC >> X) & C --> X == (log2(ShiftC) - log2(C)) ? C : 0 + if (C->isPowerOf2() && + match(Op0, m_OneUse(m_LogicalShift(m_Power2(ShiftC), m_Value(X))))) { + int Log2ShiftC = ShiftC->exactLogBase2(); + int Log2C = C->exactLogBase2(); + bool IsShiftLeft = + cast<BinaryOperator>(Op0)->getOpcode() == Instruction::Shl; + int BitNum = IsShiftLeft ? Log2C - Log2ShiftC : Log2ShiftC - Log2C; + assert(BitNum >= 0 && "Expected demanded bits to handle impossible mask"); + Value *Cmp = Builder.CreateICmpEQ(X, ConstantInt::get(Ty, BitNum)); + return SelectInst::Create(Cmp, ConstantInt::get(Ty, *C), + ConstantInt::getNullValue(Ty)); + } + + Constant *C1, *C2; + const APInt *C3 = C; + Value *X; + if (C3->isPowerOf2()) { + Constant *Log2C3 = ConstantInt::get(Ty, C3->countTrailingZeros()); + if (match(Op0, m_OneUse(m_LShr(m_Shl(m_ImmConstant(C1), m_Value(X)), + m_ImmConstant(C2)))) && + match(C1, m_Power2())) { + Constant *Log2C1 = ConstantExpr::getExactLogBase2(C1); + Constant *LshrC = ConstantExpr::getAdd(C2, Log2C3); + KnownBits KnownLShrc = computeKnownBits(LshrC, 0, nullptr); + if (KnownLShrc.getMaxValue().ult(Width)) { + // iff C1,C3 is pow2 and C2 + cttz(C3) < BitWidth: + // ((C1 << X) >> C2) & C3 -> X == (cttz(C3)+C2-cttz(C1)) ? 
C3 : 0 + Constant *CmpC = ConstantExpr::getSub(LshrC, Log2C1); + Value *Cmp = Builder.CreateICmpEQ(X, CmpC); + return SelectInst::Create(Cmp, ConstantInt::get(Ty, *C3), + ConstantInt::getNullValue(Ty)); + } + } + + if (match(Op0, m_OneUse(m_Shl(m_LShr(m_ImmConstant(C1), m_Value(X)), + m_ImmConstant(C2)))) && + match(C1, m_Power2())) { + Constant *Log2C1 = ConstantExpr::getExactLogBase2(C1); + Constant *Cmp = + ConstantExpr::getCompare(ICmpInst::ICMP_ULT, Log2C3, C2); + if (Cmp->isZeroValue()) { + // iff C1,C3 is pow2 and Log2(C3) >= C2: + // ((C1 >> X) << C2) & C3 -> X == (cttz(C1)+C2-cttz(C3)) ? C3 : 0 + Constant *ShlC = ConstantExpr::getAdd(C2, Log2C1); + Constant *CmpC = ConstantExpr::getSub(ShlC, Log2C3); + Value *Cmp = Builder.CreateICmpEQ(X, CmpC); + return SelectInst::Create(Cmp, ConstantInt::get(Ty, *C3), + ConstantInt::getNullValue(Ty)); + } + } + } + } + + if (match(&I, m_And(m_OneUse(m_Shl(m_ZExt(m_Value(X)), m_Value(Y))), + m_SignMask())) && + match(Y, m_SpecificInt_ICMP( + ICmpInst::Predicate::ICMP_EQ, + APInt(Ty->getScalarSizeInBits(), + Ty->getScalarSizeInBits() - + X->getType()->getScalarSizeInBits())))) { + auto *SExt = Builder.CreateSExt(X, Ty, X->getName() + ".signext"); + auto *SanitizedSignMask = cast<Constant>(Op1); + // We must be careful with the undef elements of the sign bit mask, however: + // the mask elt can be undef iff the shift amount for that lane was undef, + // otherwise we need to sanitize undef masks to zero. 
+ SanitizedSignMask = Constant::replaceUndefsWith( + SanitizedSignMask, ConstantInt::getNullValue(Ty->getScalarType())); + SanitizedSignMask = + Constant::mergeUndefsWith(SanitizedSignMask, cast<Constant>(Y)); + return BinaryOperator::CreateAnd(SExt, SanitizedSignMask); + } + + if (Instruction *Z = narrowMaskedBinOp(I)) + return Z; + + if (I.getType()->isIntOrIntVectorTy(1)) { + if (auto *SI0 = dyn_cast<SelectInst>(Op0)) { + if (auto *I = + foldAndOrOfSelectUsingImpliedCond(Op1, *SI0, /* IsAnd */ true)) + return I; + } + if (auto *SI1 = dyn_cast<SelectInst>(Op1)) { + if (auto *I = + foldAndOrOfSelectUsingImpliedCond(Op0, *SI1, /* IsAnd */ true)) + return I; + } + } + + if (Instruction *FoldedLogic = foldBinOpIntoSelectOrPhi(I)) + return FoldedLogic; + + if (Instruction *DeMorgan = matchDeMorgansLaws(I, Builder)) + return DeMorgan; + + { + Value *A, *B, *C; + // A & (A ^ B) --> A & ~B + if (match(Op1, m_OneUse(m_c_Xor(m_Specific(Op0), m_Value(B))))) + return BinaryOperator::CreateAnd(Op0, Builder.CreateNot(B)); + // (A ^ B) & A --> A & ~B + if (match(Op0, m_OneUse(m_c_Xor(m_Specific(Op1), m_Value(B))))) + return BinaryOperator::CreateAnd(Op1, Builder.CreateNot(B)); + + // A & ~(A ^ B) --> A & B + if (match(Op1, m_Not(m_c_Xor(m_Specific(Op0), m_Value(B))))) + return BinaryOperator::CreateAnd(Op0, B); + // ~(A ^ B) & A --> A & B + if (match(Op0, m_Not(m_c_Xor(m_Specific(Op1), m_Value(B))))) + return BinaryOperator::CreateAnd(Op1, B); + + // (A ^ B) & ((B ^ C) ^ A) -> (A ^ B) & ~C + if (match(Op0, m_Xor(m_Value(A), m_Value(B)))) + if (match(Op1, m_Xor(m_Xor(m_Specific(B), m_Value(C)), m_Specific(A)))) + if (Op1->hasOneUse() || isFreeToInvert(C, C->hasOneUse())) + return BinaryOperator::CreateAnd(Op0, Builder.CreateNot(C)); + + // ((A ^ C) ^ B) & (B ^ A) -> (B ^ A) & ~C + if (match(Op0, m_Xor(m_Xor(m_Value(A), m_Value(C)), m_Value(B)))) + if (match(Op1, m_Xor(m_Specific(B), m_Specific(A)))) + if (Op0->hasOneUse() || isFreeToInvert(C, C->hasOneUse())) + return 
BinaryOperator::CreateAnd(Op1, Builder.CreateNot(C)); + + // (A | B) & (~A ^ B) -> A & B + // (A | B) & (B ^ ~A) -> A & B + // (B | A) & (~A ^ B) -> A & B + // (B | A) & (B ^ ~A) -> A & B + if (match(Op1, m_c_Xor(m_Not(m_Value(A)), m_Value(B))) && + match(Op0, m_c_Or(m_Specific(A), m_Specific(B)))) + return BinaryOperator::CreateAnd(A, B); + + // (~A ^ B) & (A | B) -> A & B + // (~A ^ B) & (B | A) -> A & B + // (B ^ ~A) & (A | B) -> A & B + // (B ^ ~A) & (B | A) -> A & B + if (match(Op0, m_c_Xor(m_Not(m_Value(A)), m_Value(B))) && + match(Op1, m_c_Or(m_Specific(A), m_Specific(B)))) + return BinaryOperator::CreateAnd(A, B); + + // (~A | B) & (A ^ B) -> ~A & B + // (~A | B) & (B ^ A) -> ~A & B + // (B | ~A) & (A ^ B) -> ~A & B + // (B | ~A) & (B ^ A) -> ~A & B + if (match(Op0, m_c_Or(m_Not(m_Value(A)), m_Value(B))) && + match(Op1, m_c_Xor(m_Specific(A), m_Specific(B)))) + return BinaryOperator::CreateAnd(Builder.CreateNot(A), B); + + // (A ^ B) & (~A | B) -> ~A & B + // (B ^ A) & (~A | B) -> ~A & B + // (A ^ B) & (B | ~A) -> ~A & B + // (B ^ A) & (B | ~A) -> ~A & B + if (match(Op1, m_c_Or(m_Not(m_Value(A)), m_Value(B))) && + match(Op0, m_c_Xor(m_Specific(A), m_Specific(B)))) + return BinaryOperator::CreateAnd(Builder.CreateNot(A), B); + } + + { + ICmpInst *LHS = dyn_cast<ICmpInst>(Op0); + ICmpInst *RHS = dyn_cast<ICmpInst>(Op1); + if (LHS && RHS) + if (Value *Res = foldAndOrOfICmps(LHS, RHS, I, /* IsAnd */ true)) + return replaceInstUsesWith(I, Res); + + // TODO: Make this recursive; it's a little tricky because an arbitrary + // number of 'and' instructions might have to be created. + if (LHS && match(Op1, m_OneUse(m_LogicalAnd(m_Value(X), m_Value(Y))))) { + bool IsLogical = isa<SelectInst>(Op1); + // LHS & (X && Y) --> (LHS && X) && Y + if (auto *Cmp = dyn_cast<ICmpInst>(X)) + if (Value *Res = + foldAndOrOfICmps(LHS, Cmp, I, /* IsAnd */ true, IsLogical)) + return replaceInstUsesWith(I, IsLogical + ? 
Builder.CreateLogicalAnd(Res, Y) + : Builder.CreateAnd(Res, Y)); + // LHS & (X && Y) --> X && (LHS & Y) + if (auto *Cmp = dyn_cast<ICmpInst>(Y)) + if (Value *Res = foldAndOrOfICmps(LHS, Cmp, I, /* IsAnd */ true, + /* IsLogical */ false)) + return replaceInstUsesWith(I, IsLogical + ? Builder.CreateLogicalAnd(X, Res) + : Builder.CreateAnd(X, Res)); + } + if (RHS && match(Op0, m_OneUse(m_LogicalAnd(m_Value(X), m_Value(Y))))) { + bool IsLogical = isa<SelectInst>(Op0); + // (X && Y) & RHS --> (X && RHS) && Y + if (auto *Cmp = dyn_cast<ICmpInst>(X)) + if (Value *Res = + foldAndOrOfICmps(Cmp, RHS, I, /* IsAnd */ true, IsLogical)) + return replaceInstUsesWith(I, IsLogical + ? Builder.CreateLogicalAnd(Res, Y) + : Builder.CreateAnd(Res, Y)); + // (X && Y) & RHS --> X && (Y & RHS) + if (auto *Cmp = dyn_cast<ICmpInst>(Y)) + if (Value *Res = foldAndOrOfICmps(Cmp, RHS, I, /* IsAnd */ true, + /* IsLogical */ false)) + return replaceInstUsesWith(I, IsLogical + ? Builder.CreateLogicalAnd(X, Res) + : Builder.CreateAnd(X, Res)); + } + } + + if (FCmpInst *LHS = dyn_cast<FCmpInst>(I.getOperand(0))) + if (FCmpInst *RHS = dyn_cast<FCmpInst>(I.getOperand(1))) + if (Value *Res = foldLogicOfFCmps(LHS, RHS, /*IsAnd*/ true)) + return replaceInstUsesWith(I, Res); + + if (Instruction *FoldedFCmps = reassociateFCmps(I, Builder)) + return FoldedFCmps; + + if (Instruction *CastedAnd = foldCastedBitwiseLogic(I)) + return CastedAnd; + + if (Instruction *Sel = foldBinopOfSextBoolToSelect(I)) + return Sel; + + // and(sext(A), B) / and(B, sext(A)) --> A ? B : 0, where A is i1 or <N x i1>. + // TODO: Move this into foldBinopOfSextBoolToSelect as a more generalized fold + // with binop identity constant. But creating a select with non-constant + // arm may not be reversible due to poison semantics. Is that a good + // canonicalization? 
+ Value *A; + if (match(Op0, m_OneUse(m_SExt(m_Value(A)))) && + A->getType()->isIntOrIntVectorTy(1)) + return SelectInst::Create(A, Op1, Constant::getNullValue(Ty)); + if (match(Op1, m_OneUse(m_SExt(m_Value(A)))) && + A->getType()->isIntOrIntVectorTy(1)) + return SelectInst::Create(A, Op0, Constant::getNullValue(Ty)); + + // (iN X s>> (N-1)) & Y --> (X s< 0) ? Y : 0 + unsigned FullShift = Ty->getScalarSizeInBits() - 1; + if (match(&I, m_c_And(m_OneUse(m_AShr(m_Value(X), m_SpecificInt(FullShift))), + m_Value(Y)))) { + Value *IsNeg = Builder.CreateIsNeg(X, "isneg"); + return SelectInst::Create(IsNeg, Y, ConstantInt::getNullValue(Ty)); + } + // If there's a 'not' of the shifted value, swap the select operands: + // ~(iN X s>> (N-1)) & Y --> (X s< 0) ? 0 : Y + if (match(&I, m_c_And(m_OneUse(m_Not( + m_AShr(m_Value(X), m_SpecificInt(FullShift)))), + m_Value(Y)))) { + Value *IsNeg = Builder.CreateIsNeg(X, "isneg"); + return SelectInst::Create(IsNeg, ConstantInt::getNullValue(Ty), Y); + } + + // (~x) & y --> ~(x | (~y)) iff that gets rid of inversions + if (sinkNotIntoOtherHandOfAndOrOr(I)) + return &I; + + // An and recurrence w/loop invariant step is equivalent to (and start, step) + PHINode *PN = nullptr; + Value *Start = nullptr, *Step = nullptr; + if (matchSimpleRecurrence(&I, PN, Start, Step) && DT.dominates(Step, PN)) + return replaceInstUsesWith(I, Builder.CreateAnd(Start, Step)); + + return nullptr; +} + +Instruction *InstCombinerImpl::matchBSwapOrBitReverse(Instruction &I, + bool MatchBSwaps, + bool MatchBitReversals) { + SmallVector<Instruction *, 4> Insts; + if (!recognizeBSwapOrBitReverseIdiom(&I, MatchBSwaps, MatchBitReversals, + Insts)) + return nullptr; + Instruction *LastInst = Insts.pop_back_val(); + LastInst->removeFromParent(); + + for (auto *Inst : Insts) + Worklist.push(Inst); + return LastInst; +} + +/// Match UB-safe variants of the funnel shift intrinsic.
+static Instruction *matchFunnelShift(Instruction &Or, InstCombinerImpl &IC) { + // TODO: Can we reduce the code duplication between this and the related + // rotate matching code under visitSelect and visitTrunc? + unsigned Width = Or.getType()->getScalarSizeInBits(); + + // First, find an or'd pair of opposite shifts: + // or (lshr ShVal0, ShAmt0), (shl ShVal1, ShAmt1) + BinaryOperator *Or0, *Or1; + if (!match(Or.getOperand(0), m_BinOp(Or0)) || + !match(Or.getOperand(1), m_BinOp(Or1))) + return nullptr; + + Value *ShVal0, *ShVal1, *ShAmt0, *ShAmt1; + if (!match(Or0, m_OneUse(m_LogicalShift(m_Value(ShVal0), m_Value(ShAmt0)))) || + !match(Or1, m_OneUse(m_LogicalShift(m_Value(ShVal1), m_Value(ShAmt1)))) || + Or0->getOpcode() == Or1->getOpcode()) + return nullptr; + + // Canonicalize to or(shl(ShVal0, ShAmt0), lshr(ShVal1, ShAmt1)). + if (Or0->getOpcode() == BinaryOperator::LShr) { + std::swap(Or0, Or1); + std::swap(ShVal0, ShVal1); + std::swap(ShAmt0, ShAmt1); + } + assert(Or0->getOpcode() == BinaryOperator::Shl && + Or1->getOpcode() == BinaryOperator::LShr && + "Illegal or(shift,shift) pair"); + + // Match the shift amount operands for a funnel shift pattern. This always + // matches a subtraction on the R operand. + auto matchShiftAmount = [&](Value *L, Value *R, unsigned Width) -> Value * { + // Check for constant shift amounts that sum to the bitwidth. 
+ const APInt *LI, *RI; + if (match(L, m_APIntAllowUndef(LI)) && match(R, m_APIntAllowUndef(RI))) + if (LI->ult(Width) && RI->ult(Width) && (*LI + *RI) == Width) + return ConstantInt::get(L->getType(), *LI); + + Constant *LC, *RC; + if (match(L, m_Constant(LC)) && match(R, m_Constant(RC)) && + match(L, m_SpecificInt_ICMP(ICmpInst::ICMP_ULT, APInt(Width, Width))) && + match(R, m_SpecificInt_ICMP(ICmpInst::ICMP_ULT, APInt(Width, Width))) && + match(ConstantExpr::getAdd(LC, RC), m_SpecificIntAllowUndef(Width))) + return ConstantExpr::mergeUndefsWith(LC, RC); + + // (shl ShVal, X) | (lshr ShVal, (Width - x)) iff X < Width. + // We limit this to X < Width in case the backend re-expands the intrinsic, + // and has to reintroduce a shift modulo operation (InstCombine might remove + // it after this fold). This still doesn't guarantee that the final codegen + // will match this original pattern. + if (match(R, m_OneUse(m_Sub(m_SpecificInt(Width), m_Specific(L))))) { + KnownBits KnownL = IC.computeKnownBits(L, /*Depth*/ 0, &Or); + return KnownL.getMaxValue().ult(Width) ? L : nullptr; + } + + // For non-constant cases, the following patterns currently only work for + // rotation patterns. + // TODO: Add general funnel-shift compatible patterns. + if (ShVal0 != ShVal1) + return nullptr; + + // For non-constant cases we don't support non-pow2 shift masks. + // TODO: Is it worth matching urem as well? + if (!isPowerOf2_32(Width)) + return nullptr; + + // The shift amount may be masked with negation: + // (shl ShVal, (X & (Width - 1))) | (lshr ShVal, ((-X) & (Width - 1))) + Value *X; + unsigned Mask = Width - 1; + if (match(L, m_And(m_Value(X), m_SpecificInt(Mask))) && + match(R, m_And(m_Neg(m_Specific(X)), m_SpecificInt(Mask)))) + return X; + + // Similar to above, but the shift amount may be extended after masking, + // so return the extended value as the parameter for the intrinsic. 
+ if (match(L, m_ZExt(m_And(m_Value(X), m_SpecificInt(Mask)))) && + match(R, m_And(m_Neg(m_ZExt(m_And(m_Specific(X), m_SpecificInt(Mask)))), + m_SpecificInt(Mask)))) + return L; + + if (match(L, m_ZExt(m_And(m_Value(X), m_SpecificInt(Mask)))) && + match(R, m_ZExt(m_And(m_Neg(m_Specific(X)), m_SpecificInt(Mask))))) + return L; + + return nullptr; + }; + + Value *ShAmt = matchShiftAmount(ShAmt0, ShAmt1, Width); + bool IsFshl = true; // Sub on LSHR. + if (!ShAmt) { + ShAmt = matchShiftAmount(ShAmt1, ShAmt0, Width); + IsFshl = false; // Sub on SHL. + } + if (!ShAmt) + return nullptr; + + Intrinsic::ID IID = IsFshl ? Intrinsic::fshl : Intrinsic::fshr; + Function *F = Intrinsic::getDeclaration(Or.getModule(), IID, Or.getType()); + return CallInst::Create(F, {ShVal0, ShVal1, ShAmt}); +} + +/// Attempt to combine or(zext(x),shl(zext(y),bw/2) concat packing patterns. +static Instruction *matchOrConcat(Instruction &Or, + InstCombiner::BuilderTy &Builder) { + assert(Or.getOpcode() == Instruction::Or && "bswap requires an 'or'"); + Value *Op0 = Or.getOperand(0), *Op1 = Or.getOperand(1); + Type *Ty = Or.getType(); + + unsigned Width = Ty->getScalarSizeInBits(); + if ((Width & 1) != 0) + return nullptr; + unsigned HalfWidth = Width / 2; + + // Canonicalize zext (lower half) to LHS. + if (!isa<ZExtInst>(Op0)) + std::swap(Op0, Op1); + + // Find lower/upper half. 
+ Value *LowerSrc, *ShlVal, *UpperSrc; + const APInt *C; + if (!match(Op0, m_OneUse(m_ZExt(m_Value(LowerSrc)))) || + !match(Op1, m_OneUse(m_Shl(m_Value(ShlVal), m_APInt(C)))) || + !match(ShlVal, m_OneUse(m_ZExt(m_Value(UpperSrc))))) + return nullptr; + if (*C != HalfWidth || LowerSrc->getType() != UpperSrc->getType() || + LowerSrc->getType()->getScalarSizeInBits() != HalfWidth) + return nullptr; + + auto ConcatIntrinsicCalls = [&](Intrinsic::ID id, Value *Lo, Value *Hi) { + Value *NewLower = Builder.CreateZExt(Lo, Ty); + Value *NewUpper = Builder.CreateZExt(Hi, Ty); + NewUpper = Builder.CreateShl(NewUpper, HalfWidth); + Value *BinOp = Builder.CreateOr(NewLower, NewUpper); + Function *F = Intrinsic::getDeclaration(Or.getModule(), id, Ty); + return Builder.CreateCall(F, BinOp); + }; + + // BSWAP: Push the concat down, swapping the lower/upper sources. + // concat(bswap(x),bswap(y)) -> bswap(concat(x,y)) + Value *LowerBSwap, *UpperBSwap; + if (match(LowerSrc, m_BSwap(m_Value(LowerBSwap))) && + match(UpperSrc, m_BSwap(m_Value(UpperBSwap)))) + return ConcatIntrinsicCalls(Intrinsic::bswap, UpperBSwap, LowerBSwap); + + // BITREVERSE: Push the concat down, swapping the lower/upper sources. + // concat(bitreverse(x),bitreverse(y)) -> bitreverse(concat(x,y)) + Value *LowerBRev, *UpperBRev; + if (match(LowerSrc, m_BitReverse(m_Value(LowerBRev))) && + match(UpperSrc, m_BitReverse(m_Value(UpperBRev)))) + return ConcatIntrinsicCalls(Intrinsic::bitreverse, UpperBRev, LowerBRev); + + return nullptr; +} + +/// If all elements of two constant vectors are 0/-1 and inverses, return true. +static bool areInverseVectorBitmasks(Constant *C1, Constant *C2) { + unsigned NumElts = cast<FixedVectorType>(C1->getType())->getNumElements(); + for (unsigned i = 0; i != NumElts; ++i) { + Constant *EltC1 = C1->getAggregateElement(i); + Constant *EltC2 = C2->getAggregateElement(i); + if (!EltC1 || !EltC2) + return false; + + // One element must be all ones, and the other must be all zeros. 
+ if (!((match(EltC1, m_Zero()) && match(EltC2, m_AllOnes())) || + (match(EltC2, m_Zero()) && match(EltC1, m_AllOnes())))) + return false; + } + return true; +} + +/// We have an expression of the form (A & C) | (B & D). If A is a scalar or +/// vector composed of all-zeros or all-ones values and is the bitwise 'not' of +/// B, it can be used as the condition operand of a select instruction. +Value *InstCombinerImpl::getSelectCondition(Value *A, Value *B) { + // We may have peeked through bitcasts in the caller. + // Exit immediately if we don't have (vector) integer types. + Type *Ty = A->getType(); + if (!Ty->isIntOrIntVectorTy() || !B->getType()->isIntOrIntVectorTy()) + return nullptr; + + // If A is the 'not' operand of B and has enough signbits, we have our answer. + if (match(B, m_Not(m_Specific(A)))) { + // If these are scalars or vectors of i1, A can be used directly. + if (Ty->isIntOrIntVectorTy(1)) + return A; + + // If we look through a vector bitcast, the caller will bitcast the operands + // to match the condition's number of bits (N x i1). + // To make this poison-safe, disallow bitcast from wide element to narrow + // element. That could allow poison in lanes where it was not present in the + // original code. + A = peekThroughBitcast(A); + if (A->getType()->isIntOrIntVectorTy()) { + unsigned NumSignBits = ComputeNumSignBits(A); + if (NumSignBits == A->getType()->getScalarSizeInBits() && + NumSignBits <= Ty->getScalarSizeInBits()) + return Builder.CreateTrunc(A, CmpInst::makeCmpResultType(A->getType())); + } + return nullptr; + } + + // If both operands are constants, see if the constants are inverse bitmasks. + Constant *AConst, *BConst; + if (match(A, m_Constant(AConst)) && match(B, m_Constant(BConst))) + if (AConst == ConstantExpr::getNot(BConst) && + ComputeNumSignBits(A) == Ty->getScalarSizeInBits()) + return Builder.CreateZExtOrTrunc(A, CmpInst::makeCmpResultType(Ty)); + + // Look for more complex patterns. 
The 'not' op may be hidden behind various + // casts. Look through sexts and bitcasts to find the booleans. + Value *Cond; + Value *NotB; + if (match(A, m_SExt(m_Value(Cond))) && + Cond->getType()->isIntOrIntVectorTy(1)) { + // A = sext i1 Cond; B = sext (not (i1 Cond)) + if (match(B, m_SExt(m_Not(m_Specific(Cond))))) + return Cond; + + // A = sext i1 Cond; B = not ({bitcast} (sext (i1 Cond))) + // TODO: The one-use checks are unnecessary or misplaced. If the caller + // checked for uses on logic ops/casts, that should be enough to + // make this transform worthwhile. + if (match(B, m_OneUse(m_Not(m_Value(NotB))))) { + NotB = peekThroughBitcast(NotB, true); + if (match(NotB, m_SExt(m_Specific(Cond)))) + return Cond; + } + } + + // All scalar (and most vector) possibilities should be handled now. + // Try more matches that only apply to non-splat constant vectors. + if (!Ty->isVectorTy()) + return nullptr; + + // If both operands are xor'd with constants using the same sexted boolean + // operand, see if the constants are inverse bitmasks. + // TODO: Use ConstantExpr::getNot()? + if (match(A, (m_Xor(m_SExt(m_Value(Cond)), m_Constant(AConst)))) && + match(B, (m_Xor(m_SExt(m_Specific(Cond)), m_Constant(BConst)))) && + Cond->getType()->isIntOrIntVectorTy(1) && + areInverseVectorBitmasks(AConst, BConst)) { + AConst = ConstantExpr::getTrunc(AConst, CmpInst::makeCmpResultType(Ty)); + return Builder.CreateXor(Cond, AConst); + } + return nullptr; +} + +/// We have an expression of the form (A & C) | (B & D). Try to simplify this +/// to "A' ? C : D", where A' is a boolean or vector of booleans. +Value *InstCombinerImpl::matchSelectFromAndOr(Value *A, Value *C, Value *B, + Value *D) { + // The potential condition of the select may be bitcasted. In that case, look + // through its bitcast and the corresponding bitcast of the 'not' condition. 
+ Type *OrigType = A->getType(); + A = peekThroughBitcast(A, true); + B = peekThroughBitcast(B, true); + if (Value *Cond = getSelectCondition(A, B)) { + // ((bc Cond) & C) | ((bc ~Cond) & D) --> bc (select Cond, (bc C), (bc D)) + // If this is a vector, we may need to cast to match the condition's length. + // The bitcasts will either all exist or all not exist. The builder will + // not create unnecessary casts if the types already match. + Type *SelTy = A->getType(); + if (auto *VecTy = dyn_cast<VectorType>(Cond->getType())) { + // For a fixed or scalable vector get N from <{vscale x} N x iM> + unsigned Elts = VecTy->getElementCount().getKnownMinValue(); + // For a fixed or scalable vector, get the size in bits of N x iM; for a + // scalar this is just M. + unsigned SelEltSize = SelTy->getPrimitiveSizeInBits().getKnownMinSize(); + Type *EltTy = Builder.getIntNTy(SelEltSize / Elts); + SelTy = VectorType::get(EltTy, VecTy->getElementCount()); + } + Value *BitcastC = Builder.CreateBitCast(C, SelTy); + Value *BitcastD = Builder.CreateBitCast(D, SelTy); + Value *Select = Builder.CreateSelect(Cond, BitcastC, BitcastD); + return Builder.CreateBitCast(Select, OrigType); + } + + return nullptr; +} + +// (icmp eq X, 0) | (icmp ult Other, X) -> (icmp ule Other, X-1) +// (icmp ne X, 0) & (icmp uge Other, X) -> (icmp ugt Other, X-1) +Value *foldAndOrOfICmpEqZeroAndICmp(ICmpInst *LHS, ICmpInst *RHS, bool IsAnd, + IRBuilderBase &Builder) { + ICmpInst::Predicate LPred = + IsAnd ? LHS->getInversePredicate() : LHS->getPredicate(); + ICmpInst::Predicate RPred = + IsAnd ? 
RHS->getInversePredicate() : RHS->getPredicate(); + Value *LHS0 = LHS->getOperand(0); + if (LPred != ICmpInst::ICMP_EQ || !match(LHS->getOperand(1), m_Zero()) || + !LHS0->getType()->isIntOrIntVectorTy() || + !(LHS->hasOneUse() || RHS->hasOneUse())) + return nullptr; + + Value *Other; + if (RPred == ICmpInst::ICMP_ULT && RHS->getOperand(1) == LHS0) + Other = RHS->getOperand(0); + else if (RPred == ICmpInst::ICMP_UGT && RHS->getOperand(0) == LHS0) + Other = RHS->getOperand(1); + else + return nullptr; + + return Builder.CreateICmp( + IsAnd ? ICmpInst::ICMP_ULT : ICmpInst::ICMP_UGE, + Builder.CreateAdd(LHS0, Constant::getAllOnesValue(LHS0->getType())), + Other); +} + +/// Fold (icmp)&(icmp) or (icmp)|(icmp) if possible. +/// If IsLogical is true, then the and/or is in select form and the transform +/// must be poison-safe. +Value *InstCombinerImpl::foldAndOrOfICmps(ICmpInst *LHS, ICmpInst *RHS, + Instruction &I, bool IsAnd, + bool IsLogical) { + const SimplifyQuery Q = SQ.getWithInstruction(&I); + + // Fold (iszero(A & K1) | iszero(A & K2)) -> (A & (K1 | K2)) != (K1 | K2) + // Fold (!iszero(A & K1) & !iszero(A & K2)) -> (A & (K1 | K2)) == (K1 | K2) + // if K1 and K2 are a one-bit mask. + if (Value *V = foldAndOrOfICmpsOfAndWithPow2(LHS, RHS, &I, IsAnd, IsLogical)) + return V; + + ICmpInst::Predicate PredL = LHS->getPredicate(), PredR = RHS->getPredicate(); + Value *LHS0 = LHS->getOperand(0), *RHS0 = RHS->getOperand(0); + Value *LHS1 = LHS->getOperand(1), *RHS1 = RHS->getOperand(1); + const APInt *LHSC = nullptr, *RHSC = nullptr; + match(LHS1, m_APInt(LHSC)); + match(RHS1, m_APInt(RHSC)); + + // (icmp1 A, B) | (icmp2 A, B) --> (icmp3 A, B) + // (icmp1 A, B) & (icmp2 A, B) --> (icmp3 A, B) + if (predicatesFoldable(PredL, PredR)) { + if (LHS0 == RHS1 && LHS1 == RHS0) { + PredL = ICmpInst::getSwappedPredicate(PredL); + std::swap(LHS0, LHS1); + } + if (LHS0 == RHS0 && LHS1 == RHS1) { + unsigned Code = IsAnd ? 
getICmpCode(PredL) & getICmpCode(PredR) + : getICmpCode(PredL) | getICmpCode(PredR); + bool IsSigned = LHS->isSigned() || RHS->isSigned(); + return getNewICmpValue(Code, IsSigned, LHS0, LHS1, Builder); + } + } + + // handle (roughly): + // (icmp ne (A & B), C) | (icmp ne (A & D), E) + // (icmp eq (A & B), C) & (icmp eq (A & D), E) + if (Value *V = foldLogOpOfMaskedICmps(LHS, RHS, IsAnd, IsLogical, Builder)) + return V; + + // TODO: One of these directions is fine with logical and/or, the other could + // be supported by inserting freeze. + if (!IsLogical) { + if (Value *V = foldAndOrOfICmpEqZeroAndICmp(LHS, RHS, IsAnd, Builder)) + return V; + if (Value *V = foldAndOrOfICmpEqZeroAndICmp(RHS, LHS, IsAnd, Builder)) + return V; + } + + // TODO: Verify whether this is safe for logical and/or. + if (!IsLogical) { + if (Value *V = foldAndOrOfICmpsWithConstEq(LHS, RHS, IsAnd, Builder, Q)) + return V; + if (Value *V = foldAndOrOfICmpsWithConstEq(RHS, LHS, IsAnd, Builder, Q)) + return V; + } + + if (Value *V = foldIsPowerOf2OrZero(LHS, RHS, IsAnd, Builder)) + return V; + if (Value *V = foldIsPowerOf2OrZero(RHS, LHS, IsAnd, Builder)) + return V; + + // TODO: One of these directions is fine with logical and/or, the other could + // be supported by inserting freeze. + if (!IsLogical) { + // E.g. (icmp slt x, 0) | (icmp sgt x, n) --> icmp ugt x, n + // E.g. (icmp sge x, 0) & (icmp slt x, n) --> icmp ult x, n + if (Value *V = simplifyRangeCheck(LHS, RHS, /*Inverted=*/!IsAnd)) + return V; + + // E.g. (icmp sgt x, n) | (icmp slt x, 0) --> icmp ugt x, n + // E.g. (icmp slt x, n) & (icmp sge x, 0) --> icmp ult x, n + if (Value *V = simplifyRangeCheck(RHS, LHS, /*Inverted=*/!IsAnd)) + return V; + } + + // TODO: Add conjugated or fold, check whether it is safe for logical and/or. 
+ if (IsAnd && !IsLogical) + if (Value *V = foldSignedTruncationCheck(LHS, RHS, I, Builder)) + return V; + + if (Value *V = foldIsPowerOf2(LHS, RHS, IsAnd, Builder)) + return V; + + // TODO: Verify whether this is safe for logical and/or. + if (!IsLogical) { + if (Value *X = foldUnsignedUnderflowCheck(LHS, RHS, IsAnd, Q, Builder)) + return X; + if (Value *X = foldUnsignedUnderflowCheck(RHS, LHS, IsAnd, Q, Builder)) + return X; + } + + if (Value *X = foldEqOfParts(LHS, RHS, IsAnd)) + return X; + + // (icmp ne A, 0) | (icmp ne B, 0) --> (icmp ne (A|B), 0) + // (icmp eq A, 0) & (icmp eq B, 0) --> (icmp eq (A|B), 0) + // TODO: Remove this when foldLogOpOfMaskedICmps can handle undefs. + if (!IsLogical && PredL == (IsAnd ? ICmpInst::ICMP_EQ : ICmpInst::ICMP_NE) && + PredL == PredR && match(LHS1, m_ZeroInt()) && match(RHS1, m_ZeroInt()) && + LHS0->getType() == RHS0->getType()) { + Value *NewOr = Builder.CreateOr(LHS0, RHS0); + return Builder.CreateICmp(PredL, NewOr, + Constant::getNullValue(NewOr->getType())); + } + + // This only handles icmp of constants: (icmp1 A, C1) | (icmp2 B, C2). + if (!LHSC || !RHSC) + return nullptr; + + // (trunc x) == C1 & (and x, CA) == C2 -> (and x, CA|CMAX) == C1|C2 + // (trunc x) != C1 | (and x, CA) != C2 -> (and x, CA|CMAX) != C1|C2 + // where CMAX is the all ones value for the truncated type, + // iff the lower bits of C2 and CA are zero. + if (PredL == (IsAnd ? 
ICmpInst::ICMP_EQ : ICmpInst::ICMP_NE) && + PredL == PredR && LHS->hasOneUse() && RHS->hasOneUse()) { + Value *V; + const APInt *AndC, *SmallC = nullptr, *BigC = nullptr; + + // (trunc x) == C1 & (and x, CA) == C2 + // (and x, CA) == C2 & (trunc x) == C1 + if (match(RHS0, m_Trunc(m_Value(V))) && + match(LHS0, m_And(m_Specific(V), m_APInt(AndC)))) { + SmallC = RHSC; + BigC = LHSC; + } else if (match(LHS0, m_Trunc(m_Value(V))) && + match(RHS0, m_And(m_Specific(V), m_APInt(AndC)))) { + SmallC = LHSC; + BigC = RHSC; + } + + if (SmallC && BigC) { + unsigned BigBitSize = BigC->getBitWidth(); + unsigned SmallBitSize = SmallC->getBitWidth(); + + // Check that the low bits are zero. + APInt Low = APInt::getLowBitsSet(BigBitSize, SmallBitSize); + if ((Low & *AndC).isZero() && (Low & *BigC).isZero()) { + Value *NewAnd = Builder.CreateAnd(V, Low | *AndC); + APInt N = SmallC->zext(BigBitSize) | *BigC; + Value *NewVal = ConstantInt::get(NewAnd->getType(), N); + return Builder.CreateICmp(PredL, NewAnd, NewVal); + } + } + } + + // Match naive pattern (and its inverted form) for checking if two values + // share same sign. 
An example of the pattern: + // (icmp slt (X & Y), 0) | (icmp sgt (X | Y), -1) -> (icmp sgt (X ^ Y), -1) + // Inverted form (example): + // (icmp slt (X | Y), 0) & (icmp sgt (X & Y), -1) -> (icmp slt (X ^ Y), 0) + bool TrueIfSignedL, TrueIfSignedR; + if (InstCombiner::isSignBitCheck(PredL, *LHSC, TrueIfSignedL) && + InstCombiner::isSignBitCheck(PredR, *RHSC, TrueIfSignedR) && + (RHS->hasOneUse() || LHS->hasOneUse())) { + Value *X, *Y; + if (IsAnd) { + if ((TrueIfSignedL && !TrueIfSignedR && + match(LHS0, m_Or(m_Value(X), m_Value(Y))) && + match(RHS0, m_c_And(m_Specific(X), m_Specific(Y)))) || + (!TrueIfSignedL && TrueIfSignedR && + match(LHS0, m_And(m_Value(X), m_Value(Y))) && + match(RHS0, m_c_Or(m_Specific(X), m_Specific(Y))))) { + Value *NewXor = Builder.CreateXor(X, Y); + return Builder.CreateIsNeg(NewXor); + } + } else { + if ((TrueIfSignedL && !TrueIfSignedR && + match(LHS0, m_And(m_Value(X), m_Value(Y))) && + match(RHS0, m_c_Or(m_Specific(X), m_Specific(Y)))) || + (!TrueIfSignedL && TrueIfSignedR && + match(LHS0, m_Or(m_Value(X), m_Value(Y))) && + match(RHS0, m_c_And(m_Specific(X), m_Specific(Y))))) { + Value *NewXor = Builder.CreateXor(X, Y); + return Builder.CreateIsNotNeg(NewXor); + } + } + } + + return foldAndOrOfICmpsUsingRanges(LHS, RHS, IsAnd); +} + +// FIXME: We use commutative matchers (m_c_*) for some, but not all, matches +// here. We should standardize that construct where it is needed or choose some +// other way to ensure that commutated variants of patterns are not missed. 
+Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) { + if (Value *V = simplifyOrInst(I.getOperand(0), I.getOperand(1), + SQ.getWithInstruction(&I))) + return replaceInstUsesWith(I, V); + + if (SimplifyAssociativeOrCommutative(I)) + return &I; + + if (Instruction *X = foldVectorBinop(I)) + return X; + + if (Instruction *Phi = foldBinopWithPhiOperands(I)) + return Phi; + + // See if we can simplify any instructions used by the instruction whose sole + // purpose is to compute bits we don't care about. + if (SimplifyDemandedInstructionBits(I)) + return &I; + + // Do this before using distributive laws to catch simple and/or/not patterns. + if (Instruction *Xor = foldOrToXor(I, Builder)) + return Xor; + + if (Instruction *X = foldComplexAndOrPatterns(I, Builder)) + return X; + + // (A&B)|(A&C) -> A&(B|C) etc + if (Value *V = SimplifyUsingDistributiveLaws(I)) + return replaceInstUsesWith(I, V); + + if (Value *V = SimplifyBSwap(I, Builder)) + return replaceInstUsesWith(I, V); + + Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); + Type *Ty = I.getType(); + if (Ty->isIntOrIntVectorTy(1)) { + if (auto *SI0 = dyn_cast<SelectInst>(Op0)) { + if (auto *I = + foldAndOrOfSelectUsingImpliedCond(Op1, *SI0, /* IsAnd */ false)) + return I; + } + if (auto *SI1 = dyn_cast<SelectInst>(Op1)) { + if (auto *I = + foldAndOrOfSelectUsingImpliedCond(Op0, *SI1, /* IsAnd */ false)) + return I; + } + } + + if (Instruction *FoldedLogic = foldBinOpIntoSelectOrPhi(I)) + return FoldedLogic; + + if (Instruction *BitOp = matchBSwapOrBitReverse(I, /*MatchBSwaps*/ true, + /*MatchBitReversals*/ true)) + return BitOp; + + if (Instruction *Funnel = matchFunnelShift(I, *this)) + return Funnel; + + if (Instruction *Concat = matchOrConcat(I, Builder)) + return replaceInstUsesWith(I, Concat); + + Value *X, *Y; + const APInt *CV; + if (match(&I, m_c_Or(m_OneUse(m_Xor(m_Value(X), m_APInt(CV))), m_Value(Y))) && + !CV->isAllOnes() && MaskedValueIsZero(Y, *CV, 0, &I)) { + // (X ^ C) | Y -> (X | Y) ^ C 
iff Y & C == 0 + // The check for a 'not' op is for efficiency (if Y is known zero --> ~X). + Value *Or = Builder.CreateOr(X, Y); + return BinaryOperator::CreateXor(Or, ConstantInt::get(Ty, *CV)); + } + + // If the operands have no common bits set: + // or (mul X, Y), X --> add (mul X, Y), X --> mul X, (Y + 1) + if (match(&I, + m_c_Or(m_OneUse(m_Mul(m_Value(X), m_Value(Y))), m_Deferred(X))) && + haveNoCommonBitsSet(Op0, Op1, DL)) { + Value *IncrementY = Builder.CreateAdd(Y, ConstantInt::get(Ty, 1)); + return BinaryOperator::CreateMul(X, IncrementY); + } + + // (A & C) | (B & D) + Value *A, *B, *C, *D; + if (match(Op0, m_And(m_Value(A), m_Value(C))) && + match(Op1, m_And(m_Value(B), m_Value(D)))) { + + // (A & C0) | (B & C1) + const APInt *C0, *C1; + if (match(C, m_APInt(C0)) && match(D, m_APInt(C1))) { + Value *X; + if (*C0 == ~*C1) { + // ((X | B) & MaskC) | (B & ~MaskC) -> (X & MaskC) | B + if (match(A, m_c_Or(m_Value(X), m_Specific(B)))) + return BinaryOperator::CreateOr(Builder.CreateAnd(X, *C0), B); + // (A & MaskC) | ((X | A) & ~MaskC) -> (X & ~MaskC) | A + if (match(B, m_c_Or(m_Specific(A), m_Value(X)))) + return BinaryOperator::CreateOr(Builder.CreateAnd(X, *C1), A); + + // ((X ^ B) & MaskC) | (B & ~MaskC) -> (X & MaskC) ^ B + if (match(A, m_c_Xor(m_Value(X), m_Specific(B)))) + return BinaryOperator::CreateXor(Builder.CreateAnd(X, *C0), B); + // (A & MaskC) | ((X ^ A) & ~MaskC) -> (X & ~MaskC) ^ A + if (match(B, m_c_Xor(m_Specific(A), m_Value(X)))) + return BinaryOperator::CreateXor(Builder.CreateAnd(X, *C1), A); + } + + if ((*C0 & *C1).isZero()) { + // ((X | B) & C0) | (B & C1) --> (X | B) & (C0 | C1) + // iff (C0 & C1) == 0 and (X & ~C0) == 0 + if (match(A, m_c_Or(m_Value(X), m_Specific(B))) && + MaskedValueIsZero(X, ~*C0, 0, &I)) { + Constant *C01 = ConstantInt::get(Ty, *C0 | *C1); + return BinaryOperator::CreateAnd(A, C01); + } + // (A & C0) | ((X | A) & C1) --> (X | A) & (C0 | C1) + // iff (C0 & C1) == 0 and (X & ~C1) == 0 + if (match(B, 
m_c_Or(m_Value(X), m_Specific(A))) && + MaskedValueIsZero(X, ~*C1, 0, &I)) { + Constant *C01 = ConstantInt::get(Ty, *C0 | *C1); + return BinaryOperator::CreateAnd(B, C01); + } + // ((X | C2) & C0) | ((X | C3) & C1) --> (X | C2 | C3) & (C0 | C1) + // iff (C0 & C1) == 0 and (C2 & ~C0) == 0 and (C3 & ~C1) == 0. + const APInt *C2, *C3; + if (match(A, m_Or(m_Value(X), m_APInt(C2))) && + match(B, m_Or(m_Specific(X), m_APInt(C3))) && + (*C2 & ~*C0).isZero() && (*C3 & ~*C1).isZero()) { + Value *Or = Builder.CreateOr(X, *C2 | *C3, "bitfield"); + Constant *C01 = ConstantInt::get(Ty, *C0 | *C1); + return BinaryOperator::CreateAnd(Or, C01); + } + } + } + + // Don't try to form a select if it's unlikely that we'll get rid of at + // least one of the operands. A select is generally more expensive than the + // 'or' that it is replacing. + if (Op0->hasOneUse() || Op1->hasOneUse()) { + // (Cond & C) | (~Cond & D) -> Cond ? C : D, and commuted variants. + if (Value *V = matchSelectFromAndOr(A, C, B, D)) + return replaceInstUsesWith(I, V); + if (Value *V = matchSelectFromAndOr(A, C, D, B)) + return replaceInstUsesWith(I, V); + if (Value *V = matchSelectFromAndOr(C, A, B, D)) + return replaceInstUsesWith(I, V); + if (Value *V = matchSelectFromAndOr(C, A, D, B)) + return replaceInstUsesWith(I, V); + if (Value *V = matchSelectFromAndOr(B, D, A, C)) + return replaceInstUsesWith(I, V); + if (Value *V = matchSelectFromAndOr(B, D, C, A)) + return replaceInstUsesWith(I, V); + if (Value *V = matchSelectFromAndOr(D, B, A, C)) + return replaceInstUsesWith(I, V); + if (Value *V = matchSelectFromAndOr(D, B, C, A)) + return replaceInstUsesWith(I, V); + } + } + + // (A ^ B) | ((B ^ C) ^ A) -> (A ^ B) | C + if (match(Op0, m_Xor(m_Value(A), m_Value(B)))) + if (match(Op1, m_Xor(m_Xor(m_Specific(B), m_Value(C)), m_Specific(A)))) + return BinaryOperator::CreateOr(Op0, C); + + // ((A ^ C) ^ B) | (B ^ A) -> (B ^ A) | C + if (match(Op0, m_Xor(m_Xor(m_Value(A), m_Value(C)), m_Value(B)))) + if (match(Op1, 
m_Xor(m_Specific(B), m_Specific(A)))) + return BinaryOperator::CreateOr(Op1, C); + + // ((A & B) ^ C) | B -> C | B + if (match(Op0, m_c_Xor(m_c_And(m_Value(A), m_Specific(Op1)), m_Value(C)))) + return BinaryOperator::CreateOr(C, Op1); + + // B | ((A & B) ^ C) -> B | C + if (match(Op1, m_c_Xor(m_c_And(m_Value(A), m_Specific(Op0)), m_Value(C)))) + return BinaryOperator::CreateOr(Op0, C); + + // ((B | C) & A) | B -> B | (A & C) + if (match(Op0, m_And(m_Or(m_Specific(Op1), m_Value(C)), m_Value(A)))) + return BinaryOperator::CreateOr(Op1, Builder.CreateAnd(A, C)); + + if (Instruction *DeMorgan = matchDeMorgansLaws(I, Builder)) + return DeMorgan; + + // Canonicalize xor to the RHS. + bool SwappedForXor = false; + if (match(Op0, m_Xor(m_Value(), m_Value()))) { + std::swap(Op0, Op1); + SwappedForXor = true; + } + + // A | ( A ^ B) -> A | B + // A | (~A ^ B) -> A | ~B + // (A & B) | (A ^ B) + // ~A | (A ^ B) -> ~(A & B) + // The swap above should always make Op0 the 'not' for the last case. + if (match(Op1, m_Xor(m_Value(A), m_Value(B)))) { + if (Op0 == A || Op0 == B) + return BinaryOperator::CreateOr(A, B); + + if (match(Op0, m_And(m_Specific(A), m_Specific(B))) || + match(Op0, m_And(m_Specific(B), m_Specific(A)))) + return BinaryOperator::CreateOr(A, B); + + if ((Op0->hasOneUse() || Op1->hasOneUse()) && + (match(Op0, m_Not(m_Specific(A))) || match(Op0, m_Not(m_Specific(B))))) + return BinaryOperator::CreateNot(Builder.CreateAnd(A, B)); + + if (Op1->hasOneUse() && match(A, m_Not(m_Specific(Op0)))) { + Value *Not = Builder.CreateNot(B, B->getName() + ".not"); + return BinaryOperator::CreateOr(Not, Op0); + } + if (Op1->hasOneUse() && match(B, m_Not(m_Specific(Op0)))) { + Value *Not = Builder.CreateNot(A, A->getName() + ".not"); + return BinaryOperator::CreateOr(Not, Op0); + } + } + + // A | ~(A | B) -> A | ~B + // A | ~(A ^ B) -> A | ~B + if (match(Op1, m_Not(m_Value(A)))) + if (BinaryOperator *B = dyn_cast<BinaryOperator>(A)) + if ((Op0 == B->getOperand(0) || Op0 == 
B->getOperand(1)) && + Op1->hasOneUse() && (B->getOpcode() == Instruction::Or || + B->getOpcode() == Instruction::Xor)) { + Value *NotOp = Op0 == B->getOperand(0) ? B->getOperand(1) : + B->getOperand(0); + Value *Not = Builder.CreateNot(NotOp, NotOp->getName() + ".not"); + return BinaryOperator::CreateOr(Not, Op0); + } + + if (SwappedForXor) + std::swap(Op0, Op1); + + { + ICmpInst *LHS = dyn_cast<ICmpInst>(Op0); + ICmpInst *RHS = dyn_cast<ICmpInst>(Op1); + if (LHS && RHS) + if (Value *Res = foldAndOrOfICmps(LHS, RHS, I, /* IsAnd */ false)) + return replaceInstUsesWith(I, Res); + + // TODO: Make this recursive; it's a little tricky because an arbitrary + // number of 'or' instructions might have to be created. + Value *X, *Y; + if (LHS && match(Op1, m_OneUse(m_LogicalOr(m_Value(X), m_Value(Y))))) { + bool IsLogical = isa<SelectInst>(Op1); + // LHS | (X || Y) --> (LHS || X) || Y + if (auto *Cmp = dyn_cast<ICmpInst>(X)) + if (Value *Res = + foldAndOrOfICmps(LHS, Cmp, I, /* IsAnd */ false, IsLogical)) + return replaceInstUsesWith(I, IsLogical + ? Builder.CreateLogicalOr(Res, Y) + : Builder.CreateOr(Res, Y)); + // LHS | (X || Y) --> X || (LHS | Y) + if (auto *Cmp = dyn_cast<ICmpInst>(Y)) + if (Value *Res = foldAndOrOfICmps(LHS, Cmp, I, /* IsAnd */ false, + /* IsLogical */ false)) + return replaceInstUsesWith(I, IsLogical + ? Builder.CreateLogicalOr(X, Res) + : Builder.CreateOr(X, Res)); + } + if (RHS && match(Op0, m_OneUse(m_LogicalOr(m_Value(X), m_Value(Y))))) { + bool IsLogical = isa<SelectInst>(Op0); + // (X || Y) | RHS --> (X || RHS) || Y + if (auto *Cmp = dyn_cast<ICmpInst>(X)) + if (Value *Res = + foldAndOrOfICmps(Cmp, RHS, I, /* IsAnd */ false, IsLogical)) + return replaceInstUsesWith(I, IsLogical + ? 
Builder.CreateLogicalOr(Res, Y) + : Builder.CreateOr(Res, Y)); + // (X || Y) | RHS --> X || (Y | RHS) + if (auto *Cmp = dyn_cast<ICmpInst>(Y)) + if (Value *Res = foldAndOrOfICmps(Cmp, RHS, I, /* IsAnd */ false, + /* IsLogical */ false)) + return replaceInstUsesWith(I, IsLogical + ? Builder.CreateLogicalOr(X, Res) + : Builder.CreateOr(X, Res)); + } + } + + if (FCmpInst *LHS = dyn_cast<FCmpInst>(I.getOperand(0))) + if (FCmpInst *RHS = dyn_cast<FCmpInst>(I.getOperand(1))) + if (Value *Res = foldLogicOfFCmps(LHS, RHS, /*IsAnd*/ false)) + return replaceInstUsesWith(I, Res); + + if (Instruction *FoldedFCmps = reassociateFCmps(I, Builder)) + return FoldedFCmps; + + if (Instruction *CastedOr = foldCastedBitwiseLogic(I)) + return CastedOr; + + if (Instruction *Sel = foldBinopOfSextBoolToSelect(I)) + return Sel; + + // or(sext(A), B) / or(B, sext(A)) --> A ? -1 : B, where A is i1 or <N x i1>. + // TODO: Move this into foldBinopOfSextBoolToSelect as a more generalized fold + // with binop identity constant. But creating a select with non-constant + // arm may not be reversible due to poison semantics. Is that a good + // canonicalization? + if (match(Op0, m_OneUse(m_SExt(m_Value(A)))) && + A->getType()->isIntOrIntVectorTy(1)) + return SelectInst::Create(A, ConstantInt::getAllOnesValue(Ty), Op1); + if (match(Op1, m_OneUse(m_SExt(m_Value(A)))) && + A->getType()->isIntOrIntVectorTy(1)) + return SelectInst::Create(A, ConstantInt::getAllOnesValue(Ty), Op0); + + // Note: If we've gotten to the point of visiting the outer OR, then the + // inner one couldn't be simplified. If it was a constant, then it won't + // be simplified by a later pass either, so we try swapping the inner/outer + // ORs in the hopes that we'll be able to simplify it this way. 
+ // (X|C) | V --> (X|V) | C + ConstantInt *CI; + if (Op0->hasOneUse() && !match(Op1, m_ConstantInt()) && + match(Op0, m_Or(m_Value(A), m_ConstantInt(CI)))) { + Value *Inner = Builder.CreateOr(A, Op1); + Inner->takeName(Op0); + return BinaryOperator::CreateOr(Inner, CI); + } + + // Change (or (bool?A:B),(bool?C:D)) --> (bool?(or A,C):(or B,D)) + // Since this OR statement hasn't been optimized further yet, we hope + // that this transformation will allow the new ORs to be optimized. + { + Value *X = nullptr, *Y = nullptr; + if (Op0->hasOneUse() && Op1->hasOneUse() && + match(Op0, m_Select(m_Value(X), m_Value(A), m_Value(B))) && + match(Op1, m_Select(m_Value(Y), m_Value(C), m_Value(D))) && X == Y) { + Value *orTrue = Builder.CreateOr(A, C); + Value *orFalse = Builder.CreateOr(B, D); + return SelectInst::Create(X, orTrue, orFalse); + } + } + + // or(ashr(subNSW(Y, X), ScalarSizeInBits(Y) - 1), X) --> X s> Y ? -1 : X. + { + Value *X, *Y; + if (match(&I, m_c_Or(m_OneUse(m_AShr( + m_NSWSub(m_Value(Y), m_Value(X)), + m_SpecificInt(Ty->getScalarSizeInBits() - 1))), + m_Deferred(X)))) { + Value *NewICmpInst = Builder.CreateICmpSGT(X, Y); + Value *AllOnes = ConstantInt::getAllOnesValue(Ty); + return SelectInst::Create(NewICmpInst, AllOnes, X); + } + } + + if (Instruction *V = + canonicalizeCondSignextOfHighBitExtractToSignextHighBitExtract(I)) + return V; + + CmpInst::Predicate Pred; + Value *Mul, *Ov, *MulIsNotZero, *UMulWithOv; + // Check if the OR weakens the overflow condition for umul.with.overflow by + // treating any non-zero result as overflow. In that case, we overflow if both + // umul.with.overflow operands are != 0, as in that case the result can only + // be 0, iff the multiplication overflows. 
+ if (match(&I, + m_c_Or(m_CombineAnd(m_ExtractValue<1>(m_Value(UMulWithOv)), + m_Value(Ov)), + m_CombineAnd(m_ICmp(Pred, + m_CombineAnd(m_ExtractValue<0>( + m_Deferred(UMulWithOv)), + m_Value(Mul)), + m_ZeroInt()), + m_Value(MulIsNotZero)))) && + (Ov->hasOneUse() || (MulIsNotZero->hasOneUse() && Mul->hasOneUse())) && + Pred == CmpInst::ICMP_NE) { + Value *A, *B; + if (match(UMulWithOv, m_Intrinsic<Intrinsic::umul_with_overflow>( + m_Value(A), m_Value(B)))) { + Value *NotNullA = Builder.CreateIsNotNull(A); + Value *NotNullB = Builder.CreateIsNotNull(B); + return BinaryOperator::CreateAnd(NotNullA, NotNullB); + } + } + + // (~x) | y --> ~(x & (~y)) iff that gets rid of inversions + if (sinkNotIntoOtherHandOfAndOrOr(I)) + return &I; + + // Improve "get low bit mask up to and including bit X" pattern: + // (1 << X) | ((1 << X) + -1) --> -1 l>> (bitwidth(x) - 1 - X) + if (match(&I, m_c_Or(m_Add(m_Shl(m_One(), m_Value(X)), m_AllOnes()), + m_Shl(m_One(), m_Deferred(X)))) && + match(&I, m_c_Or(m_OneUse(m_Value()), m_Value()))) { + Value *Sub = Builder.CreateSub( + ConstantInt::get(Ty, Ty->getScalarSizeInBits() - 1), X); + return BinaryOperator::CreateLShr(Constant::getAllOnesValue(Ty), Sub); + } + + // An or recurrence w/loop invariant step is equivelent to (or start, step) + PHINode *PN = nullptr; + Value *Start = nullptr, *Step = nullptr; + if (matchSimpleRecurrence(&I, PN, Start, Step) && DT.dominates(Step, PN)) + return replaceInstUsesWith(I, Builder.CreateOr(Start, Step)); + + // (A & B) | (C | D) or (C | D) | (A & B) + // Can be combined if C or D is of type (A/B & X) + if (match(&I, m_c_Or(m_OneUse(m_And(m_Value(A), m_Value(B))), + m_OneUse(m_Or(m_Value(C), m_Value(D)))))) { + // (A & B) | (C | ?) -> C | (? | (A & B)) + // (A & B) | (C | ?) -> C | (? | (A & B)) + // (A & B) | (C | ?) -> C | (? | (A & B)) + // (A & B) | (C | ?) -> C | (? | (A & B)) + // (C | ?) | (A & B) -> C | (? | (A & B)) + // (C | ?) | (A & B) -> C | (? | (A & B)) + // (C | ?) 
| (A & B) -> C | (? | (A & B)) + // (C | ?) | (A & B) -> C | (? | (A & B)) + if (match(D, m_OneUse(m_c_And(m_Specific(A), m_Value()))) || + match(D, m_OneUse(m_c_And(m_Specific(B), m_Value())))) + return BinaryOperator::CreateOr( + C, Builder.CreateOr(D, Builder.CreateAnd(A, B))); + // (A & B) | (? | D) -> (? | (A & B)) | D + // (A & B) | (? | D) -> (? | (A & B)) | D + // (A & B) | (? | D) -> (? | (A & B)) | D + // (A & B) | (? | D) -> (? | (A & B)) | D + // (? | D) | (A & B) -> (? | (A & B)) | D + // (? | D) | (A & B) -> (? | (A & B)) | D + // (? | D) | (A & B) -> (? | (A & B)) | D + // (? | D) | (A & B) -> (? | (A & B)) | D + if (match(C, m_OneUse(m_c_And(m_Specific(A), m_Value()))) || + match(C, m_OneUse(m_c_And(m_Specific(B), m_Value())))) + return BinaryOperator::CreateOr( + Builder.CreateOr(C, Builder.CreateAnd(A, B)), D); + } + + return nullptr; +} + +/// A ^ B can be specified using other logic ops in a variety of patterns. We +/// can fold these early and efficiently by morphing an existing instruction. +static Instruction *foldXorToXor(BinaryOperator &I, + InstCombiner::BuilderTy &Builder) { + assert(I.getOpcode() == Instruction::Xor); + Value *Op0 = I.getOperand(0); + Value *Op1 = I.getOperand(1); + Value *A, *B; + + // There are 4 commuted variants for each of the basic patterns. 
+ + // (A & B) ^ (A | B) -> A ^ B + // (A & B) ^ (B | A) -> A ^ B + // (A | B) ^ (A & B) -> A ^ B + // (A | B) ^ (B & A) -> A ^ B + if (match(&I, m_c_Xor(m_And(m_Value(A), m_Value(B)), + m_c_Or(m_Deferred(A), m_Deferred(B))))) + return BinaryOperator::CreateXor(A, B); + + // (A | ~B) ^ (~A | B) -> A ^ B + // (~B | A) ^ (~A | B) -> A ^ B + // (~A | B) ^ (A | ~B) -> A ^ B + // (B | ~A) ^ (A | ~B) -> A ^ B + if (match(&I, m_Xor(m_c_Or(m_Value(A), m_Not(m_Value(B))), + m_c_Or(m_Not(m_Deferred(A)), m_Deferred(B))))) + return BinaryOperator::CreateXor(A, B); + + // (A & ~B) ^ (~A & B) -> A ^ B + // (~B & A) ^ (~A & B) -> A ^ B + // (~A & B) ^ (A & ~B) -> A ^ B + // (B & ~A) ^ (A & ~B) -> A ^ B + if (match(&I, m_Xor(m_c_And(m_Value(A), m_Not(m_Value(B))), + m_c_And(m_Not(m_Deferred(A)), m_Deferred(B))))) + return BinaryOperator::CreateXor(A, B); + + // For the remaining cases we need to get rid of one of the operands. + if (!Op0->hasOneUse() && !Op1->hasOneUse()) + return nullptr; + + // (A | B) ^ ~(A & B) -> ~(A ^ B) + // (A | B) ^ ~(B & A) -> ~(A ^ B) + // (A & B) ^ ~(A | B) -> ~(A ^ B) + // (A & B) ^ ~(B | A) -> ~(A ^ B) + // Complexity sorting ensures the not will be on the right side. 
+ if ((match(Op0, m_Or(m_Value(A), m_Value(B))) && + match(Op1, m_Not(m_c_And(m_Specific(A), m_Specific(B))))) || + (match(Op0, m_And(m_Value(A), m_Value(B))) && + match(Op1, m_Not(m_c_Or(m_Specific(A), m_Specific(B)))))) + return BinaryOperator::CreateNot(Builder.CreateXor(A, B)); + + return nullptr; +} + +Value *InstCombinerImpl::foldXorOfICmps(ICmpInst *LHS, ICmpInst *RHS, + BinaryOperator &I) { + assert(I.getOpcode() == Instruction::Xor && I.getOperand(0) == LHS && + I.getOperand(1) == RHS && "Should be 'xor' with these operands"); + + ICmpInst::Predicate PredL = LHS->getPredicate(), PredR = RHS->getPredicate(); + Value *LHS0 = LHS->getOperand(0), *LHS1 = LHS->getOperand(1); + Value *RHS0 = RHS->getOperand(0), *RHS1 = RHS->getOperand(1); + + if (predicatesFoldable(PredL, PredR)) { + if (LHS0 == RHS1 && LHS1 == RHS0) { + std::swap(LHS0, LHS1); + PredL = ICmpInst::getSwappedPredicate(PredL); + } + if (LHS0 == RHS0 && LHS1 == RHS1) { + // (icmp1 A, B) ^ (icmp2 A, B) --> (icmp3 A, B) + unsigned Code = getICmpCode(PredL) ^ getICmpCode(PredR); + bool IsSigned = LHS->isSigned() || RHS->isSigned(); + return getNewICmpValue(Code, IsSigned, LHS0, LHS1, Builder); + } + } + + // TODO: This can be generalized to compares of non-signbits using + // decomposeBitTestICmp(). It could be enhanced more by using (something like) + // foldLogOpOfMaskedICmps(). 
+ if ((LHS->hasOneUse() || RHS->hasOneUse()) && + LHS0->getType() == RHS0->getType() && + LHS0->getType()->isIntOrIntVectorTy()) { + // (X > -1) ^ (Y > -1) --> (X ^ Y) < 0 + // (X < 0) ^ (Y < 0) --> (X ^ Y) < 0 + if ((PredL == CmpInst::ICMP_SGT && match(LHS1, m_AllOnes()) && + PredR == CmpInst::ICMP_SGT && match(RHS1, m_AllOnes())) || + (PredL == CmpInst::ICMP_SLT && match(LHS1, m_Zero()) && + PredR == CmpInst::ICMP_SLT && match(RHS1, m_Zero()))) + return Builder.CreateIsNeg(Builder.CreateXor(LHS0, RHS0)); + + // (X > -1) ^ (Y < 0) --> (X ^ Y) > -1 + // (X < 0) ^ (Y > -1) --> (X ^ Y) > -1 + if ((PredL == CmpInst::ICMP_SGT && match(LHS1, m_AllOnes()) && + PredR == CmpInst::ICMP_SLT && match(RHS1, m_Zero())) || + (PredL == CmpInst::ICMP_SLT && match(LHS1, m_Zero()) && + PredR == CmpInst::ICMP_SGT && match(RHS1, m_AllOnes()))) + return Builder.CreateIsNotNeg(Builder.CreateXor(LHS0, RHS0)); + + } + + // Instead of trying to imitate the folds for and/or, decompose this 'xor' + // into those logic ops. That is, try to turn this into an and-of-icmps + // because we have many folds for that pattern. + // + // This is based on a truth table definition of xor: + // X ^ Y --> (X | Y) & !(X & Y) + if (Value *OrICmp = simplifyBinOp(Instruction::Or, LHS, RHS, SQ)) { + // TODO: If OrICmp is true, then the definition of xor simplifies to !(X&Y). + // TODO: If OrICmp is false, the whole thing is false (InstSimplify?). + if (Value *AndICmp = simplifyBinOp(Instruction::And, LHS, RHS, SQ)) { + // TODO: Independently handle cases where the 'and' side is a constant. + ICmpInst *X = nullptr, *Y = nullptr; + if (OrICmp == LHS && AndICmp == RHS) { + // (LHS | RHS) & !(LHS & RHS) --> LHS & !RHS --> X & !Y + X = LHS; + Y = RHS; + } + if (OrICmp == RHS && AndICmp == LHS) { + // !(LHS & RHS) & (LHS | RHS) --> !LHS & RHS --> !Y & X + X = RHS; + Y = LHS; + } + if (X && Y && (Y->hasOneUse() || canFreelyInvertAllUsersOf(Y, &I))) { + // Invert the predicate of 'Y', thus inverting its output. 
+ Y->setPredicate(Y->getInversePredicate()); + // So, are there other uses of Y? + if (!Y->hasOneUse()) { + // We need to adapt other uses of Y though. Get a value that matches + // the original value of Y before inversion. While this increases + // immediate instruction count, we have just ensured that all the + // users are freely-invertible, so that 'not' *will* get folded away. + BuilderTy::InsertPointGuard Guard(Builder); + // Set insertion point to right after the Y. + Builder.SetInsertPoint(Y->getParent(), ++(Y->getIterator())); + Value *NotY = Builder.CreateNot(Y, Y->getName() + ".not"); + // Replace all uses of Y (excluding the one in NotY!) with NotY. + Worklist.pushUsersToWorkList(*Y); + Y->replaceUsesWithIf(NotY, + [NotY](Use &U) { return U.getUser() != NotY; }); + } + // All done. + return Builder.CreateAnd(LHS, RHS); + } + } + } + + return nullptr; +} + +/// If we have a masked merge, in the canonical form of: +/// (assuming that A only has one use.) +/// | A | |B| +/// ((x ^ y) & M) ^ y +/// | D | +/// * If M is inverted: +/// | D | +/// ((x ^ y) & ~M) ^ y +/// We can canonicalize by swapping the final xor operand +/// to eliminate the 'not' of the mask. +/// ((x ^ y) & M) ^ x +/// * If M is a constant, and D has one use, we transform to 'and' / 'or' ops +/// because that shortens the dependency chain and improves analysis: +/// (x & M) | (y & ~M) +static Instruction *visitMaskedMerge(BinaryOperator &I, + InstCombiner::BuilderTy &Builder) { + Value *B, *X, *D; + Value *M; + if (!match(&I, m_c_Xor(m_Value(B), + m_OneUse(m_c_And( + m_CombineAnd(m_c_Xor(m_Deferred(B), m_Value(X)), + m_Value(D)), + m_Value(M)))))) + return nullptr; + + Value *NotM; + if (match(M, m_Not(m_Value(NotM)))) { + // De-invert the mask and swap the value in B part. + Value *NewA = Builder.CreateAnd(D, NotM); + return BinaryOperator::CreateXor(NewA, X); + } + + Constant *C; + if (D->hasOneUse() && match(M, m_Constant(C))) { + // Propagating undef is unsafe. 
Clamp undef elements to -1. + Type *EltTy = C->getType()->getScalarType(); + C = Constant::replaceUndefsWith(C, ConstantInt::getAllOnesValue(EltTy)); + // Unfold. + Value *LHS = Builder.CreateAnd(X, C); + Value *NotC = Builder.CreateNot(C); + Value *RHS = Builder.CreateAnd(B, NotC); + return BinaryOperator::CreateOr(LHS, RHS); + } + + return nullptr; +} + +// Transform +// ~(x ^ y) +// into: +// (~x) ^ y +// or into +// x ^ (~y) +static Instruction *sinkNotIntoXor(BinaryOperator &I, + InstCombiner::BuilderTy &Builder) { + Value *X, *Y; + // FIXME: one-use check is not needed in general, but currently we are unable + // to fold 'not' into 'icmp', if that 'icmp' has multiple uses. (D35182) + if (!match(&I, m_Not(m_OneUse(m_Xor(m_Value(X), m_Value(Y)))))) + return nullptr; + + // We only want to do the transform if it is free to do. + if (InstCombiner::isFreeToInvert(X, X->hasOneUse())) { + // Ok, good. + } else if (InstCombiner::isFreeToInvert(Y, Y->hasOneUse())) { + std::swap(X, Y); + } else + return nullptr; + + Value *NotX = Builder.CreateNot(X, X->getName() + ".not"); + return BinaryOperator::CreateXor(NotX, Y, I.getName() + ".demorgan"); +} + +/// Canonicalize a shifty way to code absolute value to the more common pattern +/// that uses negation and select. +static Instruction *canonicalizeAbs(BinaryOperator &Xor, + InstCombiner::BuilderTy &Builder) { + assert(Xor.getOpcode() == Instruction::Xor && "Expected an xor instruction."); + + // There are 4 potential commuted variants. Move the 'ashr' candidate to Op1. + // We're relying on the fact that we only do this transform when the shift has + // exactly 2 uses and the add has exactly 1 use (otherwise, we might increase + // instructions). 
+ Value *Op0 = Xor.getOperand(0), *Op1 = Xor.getOperand(1); + if (Op0->hasNUses(2)) + std::swap(Op0, Op1); + + Type *Ty = Xor.getType(); + Value *A; + const APInt *ShAmt; + if (match(Op1, m_AShr(m_Value(A), m_APInt(ShAmt))) && + Op1->hasNUses(2) && *ShAmt == Ty->getScalarSizeInBits() - 1 && + match(Op0, m_OneUse(m_c_Add(m_Specific(A), m_Specific(Op1))))) { + // Op1 = ashr i32 A, 31 ; smear the sign bit + // xor (add A, Op1), Op1 ; add -1 and flip bits if negative + // --> (A < 0) ? -A : A + Value *IsNeg = Builder.CreateIsNeg(A); + // Copy the nuw/nsw flags from the add to the negate. + auto *Add = cast<BinaryOperator>(Op0); + Value *NegA = Builder.CreateNeg(A, "", Add->hasNoUnsignedWrap(), + Add->hasNoSignedWrap()); + return SelectInst::Create(IsNeg, NegA, A); + } + return nullptr; +} + +// Transform +// z = (~x) &/| y +// into: +// z = ~(x |/& (~y)) +// iff y is free to invert and all uses of z can be freely updated. +bool InstCombinerImpl::sinkNotIntoOtherHandOfAndOrOr(BinaryOperator &I) { + Instruction::BinaryOps NewOpc; + switch (I.getOpcode()) { + case Instruction::And: + NewOpc = Instruction::Or; + break; + case Instruction::Or: + NewOpc = Instruction::And; + break; + default: + return false; + }; + + Value *X, *Y; + if (!match(&I, m_c_BinOp(m_Not(m_Value(X)), m_Value(Y)))) + return false; + + // Will we be able to fold the `not` into Y eventually? + if (!InstCombiner::isFreeToInvert(Y, Y->hasOneUse())) + return false; + + // And can our users be adapted? 
+ if (!InstCombiner::canFreelyInvertAllUsersOf(&I, /*IgnoredUser=*/nullptr)) + return false; + + Value *NotY = Builder.CreateNot(Y, Y->getName() + ".not"); + Value *NewBinOp = + BinaryOperator::Create(NewOpc, X, NotY, I.getName() + ".not"); + Builder.Insert(NewBinOp); + replaceInstUsesWith(I, NewBinOp); + // We can not just create an outer `not`, it will most likely be immediately + // folded back, reconstructing our initial pattern, and causing an + // infinite combine loop, so immediately manually fold it away. + freelyInvertAllUsersOf(NewBinOp); + return true; +} + +Instruction *InstCombinerImpl::foldNot(BinaryOperator &I) { + Value *NotOp; + if (!match(&I, m_Not(m_Value(NotOp)))) + return nullptr; + + // Apply DeMorgan's Law for 'nand' / 'nor' logic with an inverted operand. + // We must eliminate the and/or (one-use) for these transforms to not increase + // the instruction count. + // + // ~(~X & Y) --> (X | ~Y) + // ~(Y & ~X) --> (X | ~Y) + // + // Note: The logical matches do not check for the commuted patterns because + // those are handled via SimplifySelectsFeedingBinaryOp(). 
+ Type *Ty = I.getType(); + Value *X, *Y; + if (match(NotOp, m_OneUse(m_c_And(m_Not(m_Value(X)), m_Value(Y))))) { + Value *NotY = Builder.CreateNot(Y, Y->getName() + ".not"); + return BinaryOperator::CreateOr(X, NotY); + } + if (match(NotOp, m_OneUse(m_LogicalAnd(m_Not(m_Value(X)), m_Value(Y))))) { + Value *NotY = Builder.CreateNot(Y, Y->getName() + ".not"); + return SelectInst::Create(X, ConstantInt::getTrue(Ty), NotY); + } + + // ~(~X | Y) --> (X & ~Y) + // ~(Y | ~X) --> (X & ~Y) + if (match(NotOp, m_OneUse(m_c_Or(m_Not(m_Value(X)), m_Value(Y))))) { + Value *NotY = Builder.CreateNot(Y, Y->getName() + ".not"); + return BinaryOperator::CreateAnd(X, NotY); + } + if (match(NotOp, m_OneUse(m_LogicalOr(m_Not(m_Value(X)), m_Value(Y))))) { + Value *NotY = Builder.CreateNot(Y, Y->getName() + ".not"); + return SelectInst::Create(X, NotY, ConstantInt::getFalse(Ty)); + } + + // Is this a 'not' (~) fed by a binary operator? + BinaryOperator *NotVal; + if (match(NotOp, m_BinOp(NotVal))) { + if (NotVal->getOpcode() == Instruction::And || + NotVal->getOpcode() == Instruction::Or) { + // Apply DeMorgan's Law when inverts are free: + // ~(X & Y) --> (~X | ~Y) + // ~(X | Y) --> (~X & ~Y) + if (isFreeToInvert(NotVal->getOperand(0), + NotVal->getOperand(0)->hasOneUse()) && + isFreeToInvert(NotVal->getOperand(1), + NotVal->getOperand(1)->hasOneUse())) { + Value *NotX = Builder.CreateNot(NotVal->getOperand(0), "notlhs"); + Value *NotY = Builder.CreateNot(NotVal->getOperand(1), "notrhs"); + if (NotVal->getOpcode() == Instruction::And) + return BinaryOperator::CreateOr(NotX, NotY); + return BinaryOperator::CreateAnd(NotX, NotY); + } + } + + // ~((-X) | Y) --> (X - 1) & (~Y) + if (match(NotVal, + m_OneUse(m_c_Or(m_OneUse(m_Neg(m_Value(X))), m_Value(Y))))) { + Value *DecX = Builder.CreateAdd(X, ConstantInt::getAllOnesValue(Ty)); + Value *NotY = Builder.CreateNot(Y); + return BinaryOperator::CreateAnd(DecX, NotY); + } + + // ~(~X >>s Y) --> (X >>s Y) + if (match(NotVal, 
m_AShr(m_Not(m_Value(X)), m_Value(Y)))) + return BinaryOperator::CreateAShr(X, Y); + + // If we are inverting a right-shifted constant, we may be able to eliminate + // the 'not' by inverting the constant and using the opposite shift type. + // Canonicalization rules ensure that only a negative constant uses 'ashr', + // but we must check that in case that transform has not fired yet. + + // ~(C >>s Y) --> ~C >>u Y (when inverting the replicated sign bits) + Constant *C; + if (match(NotVal, m_AShr(m_Constant(C), m_Value(Y))) && + match(C, m_Negative())) { + // We matched a negative constant, so propagating undef is unsafe. + // Clamp undef elements to -1. + Type *EltTy = Ty->getScalarType(); + C = Constant::replaceUndefsWith(C, ConstantInt::getAllOnesValue(EltTy)); + return BinaryOperator::CreateLShr(ConstantExpr::getNot(C), Y); + } + + // ~(C >>u Y) --> ~C >>s Y (when inverting the replicated sign bits) + if (match(NotVal, m_LShr(m_Constant(C), m_Value(Y))) && + match(C, m_NonNegative())) { + // We matched a non-negative constant, so propagating undef is unsafe. + // Clamp undef elements to 0. + Type *EltTy = Ty->getScalarType(); + C = Constant::replaceUndefsWith(C, ConstantInt::getNullValue(EltTy)); + return BinaryOperator::CreateAShr(ConstantExpr::getNot(C), Y); + } + + // ~(X + C) --> ~C - X + if (match(NotVal, m_c_Add(m_Value(X), m_ImmConstant(C)))) + return BinaryOperator::CreateSub(ConstantExpr::getNot(C), X); + + // ~(X - Y) --> ~X + Y + // FIXME: is it really beneficial to sink the `not` here? 
+ if (match(NotVal, m_Sub(m_Value(X), m_Value(Y)))) + if (isa<Constant>(X) || NotVal->hasOneUse()) + return BinaryOperator::CreateAdd(Builder.CreateNot(X), Y); + + // ~(~X + Y) --> X - Y + if (match(NotVal, m_c_Add(m_Not(m_Value(X)), m_Value(Y)))) + return BinaryOperator::CreateWithCopiedFlags(Instruction::Sub, X, Y, + NotVal); + } + + // not (cmp A, B) = !cmp A, B + CmpInst::Predicate Pred; + if (match(NotOp, m_OneUse(m_Cmp(Pred, m_Value(), m_Value())))) { + cast<CmpInst>(NotOp)->setPredicate(CmpInst::getInversePredicate(Pred)); + return replaceInstUsesWith(I, NotOp); + } + + // Eliminate a bitwise 'not' op of 'not' min/max by inverting the min/max: + // ~min(~X, ~Y) --> max(X, Y) + // ~max(~X, Y) --> min(X, ~Y) + auto *II = dyn_cast<IntrinsicInst>(NotOp); + if (II && II->hasOneUse()) { + if (match(NotOp, m_MaxOrMin(m_Value(X), m_Value(Y))) && + isFreeToInvert(X, X->hasOneUse()) && + isFreeToInvert(Y, Y->hasOneUse())) { + Intrinsic::ID InvID = getInverseMinMaxIntrinsic(II->getIntrinsicID()); + Value *NotX = Builder.CreateNot(X); + Value *NotY = Builder.CreateNot(Y); + Value *InvMaxMin = Builder.CreateBinaryIntrinsic(InvID, NotX, NotY); + return replaceInstUsesWith(I, InvMaxMin); + } + if (match(NotOp, m_c_MaxOrMin(m_Not(m_Value(X)), m_Value(Y)))) { + Intrinsic::ID InvID = getInverseMinMaxIntrinsic(II->getIntrinsicID()); + Value *NotY = Builder.CreateNot(Y); + Value *InvMaxMin = Builder.CreateBinaryIntrinsic(InvID, X, NotY); + return replaceInstUsesWith(I, InvMaxMin); + } + } + + if (NotOp->hasOneUse()) { + // Pull 'not' into operands of select if both operands are one-use compares + // or one is one-use compare and the other one is a constant. + // Inverting the predicates eliminates the 'not' operation. + // Example: + // not (select ?, (cmp TPred, ?, ?), (cmp FPred, ?, ?) --> + // select ?, (cmp InvTPred, ?, ?), (cmp InvFPred, ?, ?) 
+ // not (select ?, (cmp TPred, ?, ?), true --> + // select ?, (cmp InvTPred, ?, ?), false + if (auto *Sel = dyn_cast<SelectInst>(NotOp)) { + Value *TV = Sel->getTrueValue(); + Value *FV = Sel->getFalseValue(); + auto *CmpT = dyn_cast<CmpInst>(TV); + auto *CmpF = dyn_cast<CmpInst>(FV); + bool InvertibleT = (CmpT && CmpT->hasOneUse()) || isa<Constant>(TV); + bool InvertibleF = (CmpF && CmpF->hasOneUse()) || isa<Constant>(FV); + if (InvertibleT && InvertibleF) { + if (CmpT) + CmpT->setPredicate(CmpT->getInversePredicate()); + else + Sel->setTrueValue(ConstantExpr::getNot(cast<Constant>(TV))); + if (CmpF) + CmpF->setPredicate(CmpF->getInversePredicate()); + else + Sel->setFalseValue(ConstantExpr::getNot(cast<Constant>(FV))); + return replaceInstUsesWith(I, Sel); + } + } + } + + if (Instruction *NewXor = sinkNotIntoXor(I, Builder)) + return NewXor; + + return nullptr; +} + +// FIXME: We use commutative matchers (m_c_*) for some, but not all, matches +// here. We should standardize that construct where it is needed or choose some +// other way to ensure that commutated variants of patterns are not missed. +Instruction *InstCombinerImpl::visitXor(BinaryOperator &I) { + if (Value *V = simplifyXorInst(I.getOperand(0), I.getOperand(1), + SQ.getWithInstruction(&I))) + return replaceInstUsesWith(I, V); + + if (SimplifyAssociativeOrCommutative(I)) + return &I; + + if (Instruction *X = foldVectorBinop(I)) + return X; + + if (Instruction *Phi = foldBinopWithPhiOperands(I)) + return Phi; + + if (Instruction *NewXor = foldXorToXor(I, Builder)) + return NewXor; + + // (A&B)^(A&C) -> A&(B^C) etc + if (Value *V = SimplifyUsingDistributiveLaws(I)) + return replaceInstUsesWith(I, V); + + // See if we can simplify any instructions used by the instruction whose sole + // purpose is to compute bits we don't care about. 
+ if (SimplifyDemandedInstructionBits(I)) + return &I; + + if (Value *V = SimplifyBSwap(I, Builder)) + return replaceInstUsesWith(I, V); + + if (Instruction *R = foldNot(I)) + return R; + + // Fold (X & M) ^ (Y & ~M) -> (X & M) | (Y & ~M) + // This it a special case in haveNoCommonBitsSet, but the computeKnownBits + // calls in there are unnecessary as SimplifyDemandedInstructionBits should + // have already taken care of those cases. + Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); + Value *M; + if (match(&I, m_c_Xor(m_c_And(m_Not(m_Value(M)), m_Value()), + m_c_And(m_Deferred(M), m_Value())))) + return BinaryOperator::CreateOr(Op0, Op1); + + if (Instruction *Xor = visitMaskedMerge(I, Builder)) + return Xor; + + Value *X, *Y; + Constant *C1; + if (match(Op1, m_Constant(C1))) { + Constant *C2; + + if (match(Op0, m_OneUse(m_Or(m_Value(X), m_ImmConstant(C2)))) && + match(C1, m_ImmConstant())) { + // (X | C2) ^ C1 --> (X & ~C2) ^ (C1^C2) + C2 = Constant::replaceUndefsWith( + C2, Constant::getAllOnesValue(C2->getType()->getScalarType())); + Value *And = Builder.CreateAnd( + X, Constant::mergeUndefsWith(ConstantExpr::getNot(C2), C1)); + return BinaryOperator::CreateXor( + And, Constant::mergeUndefsWith(ConstantExpr::getXor(C1, C2), C1)); + } + + // Use DeMorgan and reassociation to eliminate a 'not' op. 
+ if (match(Op0, m_OneUse(m_Or(m_Not(m_Value(X)), m_Constant(C2))))) { + // (~X | C2) ^ C1 --> ((X & ~C2) ^ -1) ^ C1 --> (X & ~C2) ^ ~C1 + Value *And = Builder.CreateAnd(X, ConstantExpr::getNot(C2)); + return BinaryOperator::CreateXor(And, ConstantExpr::getNot(C1)); + } + if (match(Op0, m_OneUse(m_And(m_Not(m_Value(X)), m_Constant(C2))))) { + // (~X & C2) ^ C1 --> ((X | ~C2) ^ -1) ^ C1 --> (X | ~C2) ^ ~C1 + Value *Or = Builder.CreateOr(X, ConstantExpr::getNot(C2)); + return BinaryOperator::CreateXor(Or, ConstantExpr::getNot(C1)); + } + + // Convert xor ([trunc] (ashr X, BW-1)), C => + // select(X >s -1, C, ~C) + // The ashr creates "AllZeroOrAllOne's", which then optionally inverses the + // constant depending on whether this input is less than 0. + const APInt *CA; + if (match(Op0, m_OneUse(m_TruncOrSelf( + m_AShr(m_Value(X), m_APIntAllowUndef(CA))))) && + *CA == X->getType()->getScalarSizeInBits() - 1 && + !match(C1, m_AllOnes())) { + assert(!C1->isZeroValue() && "Unexpected xor with 0"); + Value *IsNotNeg = Builder.CreateIsNotNeg(X); + return SelectInst::Create(IsNotNeg, Op1, Builder.CreateNot(Op1)); + } + } + + Type *Ty = I.getType(); + { + const APInt *RHSC; + if (match(Op1, m_APInt(RHSC))) { + Value *X; + const APInt *C; + // (C - X) ^ signmaskC --> (C + signmaskC) - X + if (RHSC->isSignMask() && match(Op0, m_Sub(m_APInt(C), m_Value(X)))) + return BinaryOperator::CreateSub(ConstantInt::get(Ty, *C + *RHSC), X); + + // (X + C) ^ signmaskC --> X + (C + signmaskC) + if (RHSC->isSignMask() && match(Op0, m_Add(m_Value(X), m_APInt(C)))) + return BinaryOperator::CreateAdd(X, ConstantInt::get(Ty, *C + *RHSC)); + + // (X | C) ^ RHSC --> X ^ (C ^ RHSC) iff X & C == 0 + if (match(Op0, m_Or(m_Value(X), m_APInt(C))) && + MaskedValueIsZero(X, *C, 0, &I)) + return BinaryOperator::CreateXor(X, ConstantInt::get(Ty, *C ^ *RHSC)); + + // If RHSC is inverting the remaining bits of shifted X, + // canonicalize to a 'not' before the shift to help SCEV and codegen: + // (X << C) ^ 
RHSC --> ~X << C + if (match(Op0, m_OneUse(m_Shl(m_Value(X), m_APInt(C)))) && + *RHSC == APInt::getAllOnes(Ty->getScalarSizeInBits()).shl(*C)) { + Value *NotX = Builder.CreateNot(X); + return BinaryOperator::CreateShl(NotX, ConstantInt::get(Ty, *C)); + } + // (X >>u C) ^ RHSC --> ~X >>u C + if (match(Op0, m_OneUse(m_LShr(m_Value(X), m_APInt(C)))) && + *RHSC == APInt::getAllOnes(Ty->getScalarSizeInBits()).lshr(*C)) { + Value *NotX = Builder.CreateNot(X); + return BinaryOperator::CreateLShr(NotX, ConstantInt::get(Ty, *C)); + } + // TODO: We could handle 'ashr' here as well. That would be matching + // a 'not' op and moving it before the shift. Doing that requires + // preventing the inverse fold in canShiftBinOpWithConstantRHS(). + } + } + + // FIXME: This should not be limited to scalar (pull into APInt match above). + { + Value *X; + ConstantInt *C1, *C2, *C3; + // ((X^C1) >> C2) ^ C3 -> (X>>C2) ^ ((C1>>C2)^C3) + if (match(Op1, m_ConstantInt(C3)) && + match(Op0, m_LShr(m_Xor(m_Value(X), m_ConstantInt(C1)), + m_ConstantInt(C2))) && + Op0->hasOneUse()) { + // fold (C1 >> C2) ^ C3 + APInt FoldConst = C1->getValue().lshr(C2->getValue()); + FoldConst ^= C3->getValue(); + // Prepare the two operands. 
+ auto *Opnd0 = Builder.CreateLShr(X, C2); + Opnd0->takeName(Op0); + return BinaryOperator::CreateXor(Opnd0, ConstantInt::get(Ty, FoldConst)); + } + } + + if (Instruction *FoldedLogic = foldBinOpIntoSelectOrPhi(I)) + return FoldedLogic; + + // Y ^ (X | Y) --> X & ~Y + // Y ^ (Y | X) --> X & ~Y + if (match(Op1, m_OneUse(m_c_Or(m_Value(X), m_Specific(Op0))))) + return BinaryOperator::CreateAnd(X, Builder.CreateNot(Op0)); + // (X | Y) ^ Y --> X & ~Y + // (Y | X) ^ Y --> X & ~Y + if (match(Op0, m_OneUse(m_c_Or(m_Value(X), m_Specific(Op1))))) + return BinaryOperator::CreateAnd(X, Builder.CreateNot(Op1)); + + // Y ^ (X & Y) --> ~X & Y + // Y ^ (Y & X) --> ~X & Y + if (match(Op1, m_OneUse(m_c_And(m_Value(X), m_Specific(Op0))))) + return BinaryOperator::CreateAnd(Op0, Builder.CreateNot(X)); + // (X & Y) ^ Y --> ~X & Y + // (Y & X) ^ Y --> ~X & Y + // Canonical form is (X & C) ^ C; don't touch that. + // TODO: A 'not' op is better for analysis and codegen, but demanded bits must + // be fixed to prefer that (otherwise we get infinite looping). + if (!match(Op1, m_Constant()) && + match(Op0, m_OneUse(m_c_And(m_Value(X), m_Specific(Op1))))) + return BinaryOperator::CreateAnd(Op1, Builder.CreateNot(X)); + + Value *A, *B, *C; + // (A ^ B) ^ (A | C) --> (~A & C) ^ B -- There are 4 commuted variants. + if (match(&I, m_c_Xor(m_OneUse(m_Xor(m_Value(A), m_Value(B))), + m_OneUse(m_c_Or(m_Deferred(A), m_Value(C)))))) + return BinaryOperator::CreateXor( + Builder.CreateAnd(Builder.CreateNot(A), C), B); + + // (A ^ B) ^ (B | C) --> (~B & C) ^ A -- There are 4 commuted variants. 
+ if (match(&I, m_c_Xor(m_OneUse(m_Xor(m_Value(A), m_Value(B))), + m_OneUse(m_c_Or(m_Deferred(B), m_Value(C)))))) + return BinaryOperator::CreateXor( + Builder.CreateAnd(Builder.CreateNot(B), C), A); + + // (A & B) ^ (A ^ B) -> (A | B) + if (match(Op0, m_And(m_Value(A), m_Value(B))) && + match(Op1, m_c_Xor(m_Specific(A), m_Specific(B)))) + return BinaryOperator::CreateOr(A, B); + // (A ^ B) ^ (A & B) -> (A | B) + if (match(Op0, m_Xor(m_Value(A), m_Value(B))) && + match(Op1, m_c_And(m_Specific(A), m_Specific(B)))) + return BinaryOperator::CreateOr(A, B); + + // (A & ~B) ^ ~A -> ~(A & B) + // (~B & A) ^ ~A -> ~(A & B) + if (match(Op0, m_c_And(m_Value(A), m_Not(m_Value(B)))) && + match(Op1, m_Not(m_Specific(A)))) + return BinaryOperator::CreateNot(Builder.CreateAnd(A, B)); + + // (~A & B) ^ A --> A | B -- There are 4 commuted variants. + if (match(&I, m_c_Xor(m_c_And(m_Not(m_Value(A)), m_Value(B)), m_Deferred(A)))) + return BinaryOperator::CreateOr(A, B); + + // (~A | B) ^ A --> ~(A & B) + if (match(Op0, m_OneUse(m_c_Or(m_Not(m_Specific(Op1)), m_Value(B))))) + return BinaryOperator::CreateNot(Builder.CreateAnd(Op1, B)); + + // A ^ (~A | B) --> ~(A & B) + if (match(Op1, m_OneUse(m_c_Or(m_Not(m_Specific(Op0)), m_Value(B))))) + return BinaryOperator::CreateNot(Builder.CreateAnd(Op0, B)); + + // (A | B) ^ (A | C) --> (B ^ C) & ~A -- There are 4 commuted variants. + // TODO: Loosen one-use restriction if common operand is a constant. 
+ Value *D; + if (match(Op0, m_OneUse(m_Or(m_Value(A), m_Value(B)))) && + match(Op1, m_OneUse(m_Or(m_Value(C), m_Value(D))))) { + if (B == C || B == D) + std::swap(A, B); + if (A == C) + std::swap(C, D); + if (A == D) { + Value *NotA = Builder.CreateNot(A); + return BinaryOperator::CreateAnd(Builder.CreateXor(B, C), NotA); + } + } + + if (auto *LHS = dyn_cast<ICmpInst>(I.getOperand(0))) + if (auto *RHS = dyn_cast<ICmpInst>(I.getOperand(1))) + if (Value *V = foldXorOfICmps(LHS, RHS, I)) + return replaceInstUsesWith(I, V); + + if (Instruction *CastedXor = foldCastedBitwiseLogic(I)) + return CastedXor; + + if (Instruction *Abs = canonicalizeAbs(I, Builder)) + return Abs; + + // Otherwise, if all else failed, try to hoist the xor-by-constant: + // (X ^ C) ^ Y --> (X ^ Y) ^ C + // Just like we do in other places, we completely avoid the fold + // for constantexprs, at least to avoid endless combine loop. + if (match(&I, m_c_Xor(m_OneUse(m_Xor(m_CombineAnd(m_Value(X), + m_Unless(m_ConstantExpr())), + m_ImmConstant(C1))), + m_Value(Y)))) + return BinaryOperator::CreateXor(Builder.CreateXor(X, Y), C1); + + return nullptr; +} diff --git a/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineAtomicRMW.cpp b/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineAtomicRMW.cpp new file mode 100644 index 000000000000..2540e545ae4d --- /dev/null +++ b/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineAtomicRMW.cpp @@ -0,0 +1,158 @@ +//===- InstCombineAtomicRMW.cpp -------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the visit functions for atomic rmw instructions. 
+// +//===----------------------------------------------------------------------===// + +#include "InstCombineInternal.h" +#include "llvm/IR/Instructions.h" + +using namespace llvm; + +namespace { +/// Return true if and only if the given instruction does not modify the memory +/// location referenced. Note that an idemptent atomicrmw may still have +/// ordering effects on nearby instructions, or be volatile. +/// TODO: Common w/ the version in AtomicExpandPass, and change the term used. +/// Idemptotent is confusing in this context. +bool isIdempotentRMW(AtomicRMWInst& RMWI) { + if (auto CF = dyn_cast<ConstantFP>(RMWI.getValOperand())) + switch(RMWI.getOperation()) { + case AtomicRMWInst::FAdd: // -0.0 + return CF->isZero() && CF->isNegative(); + case AtomicRMWInst::FSub: // +0.0 + return CF->isZero() && !CF->isNegative(); + default: + return false; + }; + + auto C = dyn_cast<ConstantInt>(RMWI.getValOperand()); + if(!C) + return false; + + switch(RMWI.getOperation()) { + case AtomicRMWInst::Add: + case AtomicRMWInst::Sub: + case AtomicRMWInst::Or: + case AtomicRMWInst::Xor: + return C->isZero(); + case AtomicRMWInst::And: + return C->isMinusOne(); + case AtomicRMWInst::Min: + return C->isMaxValue(true); + case AtomicRMWInst::Max: + return C->isMinValue(true); + case AtomicRMWInst::UMin: + return C->isMaxValue(false); + case AtomicRMWInst::UMax: + return C->isMinValue(false); + default: + return false; + } +} + +/// Return true if the given instruction always produces a value in memory +/// equivalent to its value operand. 
+bool isSaturating(AtomicRMWInst& RMWI) { + if (auto CF = dyn_cast<ConstantFP>(RMWI.getValOperand())) + switch(RMWI.getOperation()) { + case AtomicRMWInst::FAdd: + case AtomicRMWInst::FSub: + return CF->isNaN(); + default: + return false; + }; + + auto C = dyn_cast<ConstantInt>(RMWI.getValOperand()); + if(!C) + return false; + + switch(RMWI.getOperation()) { + default: + return false; + case AtomicRMWInst::Xchg: + return true; + case AtomicRMWInst::Or: + return C->isAllOnesValue(); + case AtomicRMWInst::And: + return C->isZero(); + case AtomicRMWInst::Min: + return C->isMinValue(true); + case AtomicRMWInst::Max: + return C->isMaxValue(true); + case AtomicRMWInst::UMin: + return C->isMinValue(false); + case AtomicRMWInst::UMax: + return C->isMaxValue(false); + }; +} +} // namespace + +Instruction *InstCombinerImpl::visitAtomicRMWInst(AtomicRMWInst &RMWI) { + + // Volatile RMWs perform a load and a store, we cannot replace this by just a + // load or just a store. We chose not to canonicalize out of general paranoia + // about user expectations around volatile. + if (RMWI.isVolatile()) + return nullptr; + + // Any atomicrmw op which produces a known result in memory can be + // replaced w/an atomicrmw xchg. + if (isSaturating(RMWI) && + RMWI.getOperation() != AtomicRMWInst::Xchg) { + RMWI.setOperation(AtomicRMWInst::Xchg); + return &RMWI; + } + + AtomicOrdering Ordering = RMWI.getOrdering(); + assert(Ordering != AtomicOrdering::NotAtomic && + Ordering != AtomicOrdering::Unordered && + "AtomicRMWs don't make sense with Unordered or NotAtomic"); + + // Any atomicrmw xchg with no uses can be converted to a atomic store if the + // ordering is compatible. 
+ if (RMWI.getOperation() == AtomicRMWInst::Xchg && + RMWI.use_empty()) { + if (Ordering != AtomicOrdering::Release && + Ordering != AtomicOrdering::Monotonic) + return nullptr; + auto *SI = new StoreInst(RMWI.getValOperand(), + RMWI.getPointerOperand(), &RMWI); + SI->setAtomic(Ordering, RMWI.getSyncScopeID()); + SI->setAlignment(DL.getABITypeAlign(RMWI.getType())); + return eraseInstFromFunction(RMWI); + } + + if (!isIdempotentRMW(RMWI)) + return nullptr; + + // We chose to canonicalize all idempotent operations to an single + // operation code and constant. This makes it easier for the rest of the + // optimizer to match easily. The choices of or w/0 and fadd w/-0.0 are + // arbitrary. + if (RMWI.getType()->isIntegerTy() && + RMWI.getOperation() != AtomicRMWInst::Or) { + RMWI.setOperation(AtomicRMWInst::Or); + return replaceOperand(RMWI, 1, ConstantInt::get(RMWI.getType(), 0)); + } else if (RMWI.getType()->isFloatingPointTy() && + RMWI.getOperation() != AtomicRMWInst::FAdd) { + RMWI.setOperation(AtomicRMWInst::FAdd); + return replaceOperand(RMWI, 1, ConstantFP::getNegativeZero(RMWI.getType())); + } + + // Check if the required ordering is compatible with an atomic load. + if (Ordering != AtomicOrdering::Acquire && + Ordering != AtomicOrdering::Monotonic) + return nullptr; + + LoadInst *Load = new LoadInst(RMWI.getType(), RMWI.getPointerOperand(), "", + false, DL.getABITypeAlign(RMWI.getType()), + Ordering, RMWI.getSyncScopeID()); + return Load; +} diff --git a/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp new file mode 100644 index 000000000000..67ef2e895b6c --- /dev/null +++ b/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp @@ -0,0 +1,3630 @@ +//===- InstCombineCalls.cpp -----------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 
+// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the visitCall, visitInvoke, and visitCallBr functions. +// +//===----------------------------------------------------------------------===// + +#include "InstCombineInternal.h" +#include "llvm/ADT/APFloat.h" +#include "llvm/ADT/APInt.h" +#include "llvm/ADT/APSInt.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/None.h" +#include "llvm/ADT/Optional.h" +#include "llvm/ADT/STLFunctionalExtras.h" +#include "llvm/ADT/SmallBitVector.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/Statistic.h" +#include "llvm/Analysis/AliasAnalysis.h" +#include "llvm/Analysis/AssumeBundleQueries.h" +#include "llvm/Analysis/AssumptionCache.h" +#include "llvm/Analysis/InstructionSimplify.h" +#include "llvm/Analysis/Loads.h" +#include "llvm/Analysis/MemoryBuiltins.h" +#include "llvm/Analysis/ValueTracking.h" +#include "llvm/Analysis/VectorUtils.h" +#include "llvm/IR/Attributes.h" +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/Constant.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/DataLayout.h" +#include "llvm/IR/DerivedTypes.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/GlobalVariable.h" +#include "llvm/IR/InlineAsm.h" +#include "llvm/IR/InstrTypes.h" +#include "llvm/IR/Instruction.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/Intrinsics.h" +#include "llvm/IR/IntrinsicsAArch64.h" +#include "llvm/IR/IntrinsicsAMDGPU.h" +#include "llvm/IR/IntrinsicsARM.h" +#include "llvm/IR/IntrinsicsHexagon.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/Metadata.h" +#include "llvm/IR/PatternMatch.h" +#include "llvm/IR/Statepoint.h" +#include "llvm/IR/Type.h" +#include "llvm/IR/User.h" +#include "llvm/IR/Value.h" +#include "llvm/IR/ValueHandle.h" +#include "llvm/Support/AtomicOrdering.h" 
+#include "llvm/Support/Casting.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/Compiler.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/KnownBits.h" +#include "llvm/Support/MathExtras.h" +#include "llvm/Support/raw_ostream.h" +#include "llvm/Transforms/InstCombine/InstCombiner.h" +#include "llvm/Transforms/Utils/AssumeBundleBuilder.h" +#include "llvm/Transforms/Utils/Local.h" +#include "llvm/Transforms/Utils/SimplifyLibCalls.h" +#include <algorithm> +#include <cassert> +#include <cstdint> +#include <utility> +#include <vector> + +#define DEBUG_TYPE "instcombine" +#include "llvm/Transforms/Utils/InstructionWorklist.h" + +using namespace llvm; +using namespace PatternMatch; + +STATISTIC(NumSimplified, "Number of library calls simplified"); + +static cl::opt<unsigned> GuardWideningWindow( + "instcombine-guard-widening-window", + cl::init(3), + cl::desc("How wide an instruction window to bypass looking for " + "another guard")); + +namespace llvm { +/// enable preservation of attributes in assume like: +/// call void @llvm.assume(i1 true) [ "nonnull"(i32* %PTR) ] +extern cl::opt<bool> EnableKnowledgeRetention; +} // namespace llvm + +/// Return the specified type promoted as it would be to pass though a va_arg +/// area. +static Type *getPromotedType(Type *Ty) { + if (IntegerType* ITy = dyn_cast<IntegerType>(Ty)) { + if (ITy->getBitWidth() < 32) + return Type::getInt32Ty(Ty->getContext()); + } + return Ty; +} + +/// Recognize a memcpy/memmove from a trivially otherwise unused alloca. +/// TODO: This should probably be integrated with visitAllocSites, but that +/// requires a deeper change to allow either unread or unwritten objects. 
+static bool hasUndefSource(AnyMemTransferInst *MI) { + auto *Src = MI->getRawSource(); + while (isa<GetElementPtrInst>(Src) || isa<BitCastInst>(Src)) { + if (!Src->hasOneUse()) + return false; + Src = cast<Instruction>(Src)->getOperand(0); + } + return isa<AllocaInst>(Src) && Src->hasOneUse(); +} + +Instruction *InstCombinerImpl::SimplifyAnyMemTransfer(AnyMemTransferInst *MI) { + Align DstAlign = getKnownAlignment(MI->getRawDest(), DL, MI, &AC, &DT); + MaybeAlign CopyDstAlign = MI->getDestAlign(); + if (!CopyDstAlign || *CopyDstAlign < DstAlign) { + MI->setDestAlignment(DstAlign); + return MI; + } + + Align SrcAlign = getKnownAlignment(MI->getRawSource(), DL, MI, &AC, &DT); + MaybeAlign CopySrcAlign = MI->getSourceAlign(); + if (!CopySrcAlign || *CopySrcAlign < SrcAlign) { + MI->setSourceAlignment(SrcAlign); + return MI; + } + + // If we have a store to a location which is known constant, we can conclude + // that the store must be storing the constant value (else the memory + // wouldn't be constant), and this must be a noop. + if (AA->pointsToConstantMemory(MI->getDest())) { + // Set the size of the copy to 0, it will be deleted on the next iteration. + MI->setLength(Constant::getNullValue(MI->getLength()->getType())); + return MI; + } + + // If the source is provably undef, the memcpy/memmove doesn't do anything + // (unless the transfer is volatile). + if (hasUndefSource(MI) && !MI->isVolatile()) { + // Set the size of the copy to 0, it will be deleted on the next iteration. + MI->setLength(Constant::getNullValue(MI->getLength()->getType())); + return MI; + } + + // If MemCpyInst length is 1/2/4/8 bytes then replace memcpy with + // load/store. + ConstantInt *MemOpLength = dyn_cast<ConstantInt>(MI->getLength()); + if (!MemOpLength) return nullptr; + + // Source and destination pointer types are always "i8*" for intrinsic. See + // if the size is something we can handle with a single primitive load/store. 
+ // A single load+store correctly handles overlapping memory in the memmove + // case. + uint64_t Size = MemOpLength->getLimitedValue(); + assert(Size && "0-sized memory transferring should be removed already."); + + if (Size > 8 || (Size&(Size-1))) + return nullptr; // If not 1/2/4/8 bytes, exit. + + // If it is an atomic and alignment is less than the size then we will + // introduce the unaligned memory access which will be later transformed + // into libcall in CodeGen. This is not evident performance gain so disable + // it now. + if (isa<AtomicMemTransferInst>(MI)) + if (*CopyDstAlign < Size || *CopySrcAlign < Size) + return nullptr; + + // Use an integer load+store unless we can find something better. + unsigned SrcAddrSp = + cast<PointerType>(MI->getArgOperand(1)->getType())->getAddressSpace(); + unsigned DstAddrSp = + cast<PointerType>(MI->getArgOperand(0)->getType())->getAddressSpace(); + + IntegerType* IntType = IntegerType::get(MI->getContext(), Size<<3); + Type *NewSrcPtrTy = PointerType::get(IntType, SrcAddrSp); + Type *NewDstPtrTy = PointerType::get(IntType, DstAddrSp); + + // If the memcpy has metadata describing the members, see if we can get the + // TBAA tag describing our copy. 
+ MDNode *CopyMD = nullptr; + if (MDNode *M = MI->getMetadata(LLVMContext::MD_tbaa)) { + CopyMD = M; + } else if (MDNode *M = MI->getMetadata(LLVMContext::MD_tbaa_struct)) { + if (M->getNumOperands() == 3 && M->getOperand(0) && + mdconst::hasa<ConstantInt>(M->getOperand(0)) && + mdconst::extract<ConstantInt>(M->getOperand(0))->isZero() && + M->getOperand(1) && + mdconst::hasa<ConstantInt>(M->getOperand(1)) && + mdconst::extract<ConstantInt>(M->getOperand(1))->getValue() == + Size && + M->getOperand(2) && isa<MDNode>(M->getOperand(2))) + CopyMD = cast<MDNode>(M->getOperand(2)); + } + + Value *Src = Builder.CreateBitCast(MI->getArgOperand(1), NewSrcPtrTy); + Value *Dest = Builder.CreateBitCast(MI->getArgOperand(0), NewDstPtrTy); + LoadInst *L = Builder.CreateLoad(IntType, Src); + // Alignment from the mem intrinsic will be better, so use it. + L->setAlignment(*CopySrcAlign); + if (CopyMD) + L->setMetadata(LLVMContext::MD_tbaa, CopyMD); + MDNode *LoopMemParallelMD = + MI->getMetadata(LLVMContext::MD_mem_parallel_loop_access); + if (LoopMemParallelMD) + L->setMetadata(LLVMContext::MD_mem_parallel_loop_access, LoopMemParallelMD); + MDNode *AccessGroupMD = MI->getMetadata(LLVMContext::MD_access_group); + if (AccessGroupMD) + L->setMetadata(LLVMContext::MD_access_group, AccessGroupMD); + + StoreInst *S = Builder.CreateStore(L, Dest); + // Alignment from the mem intrinsic will be better, so use it. 
+ S->setAlignment(*CopyDstAlign); + if (CopyMD) + S->setMetadata(LLVMContext::MD_tbaa, CopyMD); + if (LoopMemParallelMD) + S->setMetadata(LLVMContext::MD_mem_parallel_loop_access, LoopMemParallelMD); + if (AccessGroupMD) + S->setMetadata(LLVMContext::MD_access_group, AccessGroupMD); + + if (auto *MT = dyn_cast<MemTransferInst>(MI)) { + // non-atomics can be volatile + L->setVolatile(MT->isVolatile()); + S->setVolatile(MT->isVolatile()); + } + if (isa<AtomicMemTransferInst>(MI)) { + // atomics have to be unordered + L->setOrdering(AtomicOrdering::Unordered); + S->setOrdering(AtomicOrdering::Unordered); + } + + // Set the size of the copy to 0, it will be deleted on the next iteration. + MI->setLength(Constant::getNullValue(MemOpLength->getType())); + return MI; +} + +Instruction *InstCombinerImpl::SimplifyAnyMemSet(AnyMemSetInst *MI) { + const Align KnownAlignment = + getKnownAlignment(MI->getDest(), DL, MI, &AC, &DT); + MaybeAlign MemSetAlign = MI->getDestAlign(); + if (!MemSetAlign || *MemSetAlign < KnownAlignment) { + MI->setDestAlignment(KnownAlignment); + return MI; + } + + // If we have a store to a location which is known constant, we can conclude + // that the store must be storing the constant value (else the memory + // wouldn't be constant), and this must be a noop. + if (AA->pointsToConstantMemory(MI->getDest())) { + // Set the size of the copy to 0, it will be deleted on the next iteration. + MI->setLength(Constant::getNullValue(MI->getLength()->getType())); + return MI; + } + + // Remove memset with an undef value. + // FIXME: This is technically incorrect because it might overwrite a poison + // value. Change to PoisonValue once #52930 is resolved. + if (isa<UndefValue>(MI->getValue())) { + // Set the size of the copy to 0, it will be deleted on the next iteration. + MI->setLength(Constant::getNullValue(MI->getLength()->getType())); + return MI; + } + + // Extract the length and alignment and fill if they are constant. 
+ ConstantInt *LenC = dyn_cast<ConstantInt>(MI->getLength()); + ConstantInt *FillC = dyn_cast<ConstantInt>(MI->getValue()); + if (!LenC || !FillC || !FillC->getType()->isIntegerTy(8)) + return nullptr; + const uint64_t Len = LenC->getLimitedValue(); + assert(Len && "0-sized memory setting should be removed already."); + const Align Alignment = MI->getDestAlign().valueOrOne(); + + // If it is an atomic and alignment is less than the size then we will + // introduce the unaligned memory access which will be later transformed + // into libcall in CodeGen. This is not evident performance gain so disable + // it now. + if (isa<AtomicMemSetInst>(MI)) + if (Alignment < Len) + return nullptr; + + // memset(s,c,n) -> store s, c (for n=1,2,4,8) + if (Len <= 8 && isPowerOf2_32((uint32_t)Len)) { + Type *ITy = IntegerType::get(MI->getContext(), Len*8); // n=1 -> i8. + + Value *Dest = MI->getDest(); + unsigned DstAddrSp = cast<PointerType>(Dest->getType())->getAddressSpace(); + Type *NewDstPtrTy = PointerType::get(ITy, DstAddrSp); + Dest = Builder.CreateBitCast(Dest, NewDstPtrTy); + + // Extract the fill value and store. + uint64_t Fill = FillC->getZExtValue()*0x0101010101010101ULL; + StoreInst *S = Builder.CreateStore(ConstantInt::get(ITy, Fill), Dest, + MI->isVolatile()); + S->setAlignment(Alignment); + if (isa<AtomicMemSetInst>(MI)) + S->setOrdering(AtomicOrdering::Unordered); + + // Set the size of the copy to 0, it will be deleted on the next iteration. + MI->setLength(Constant::getNullValue(LenC->getType())); + return MI; + } + + return nullptr; +} + +// TODO, Obvious Missing Transforms: +// * Narrow width by halfs excluding zero/undef lanes +Value *InstCombinerImpl::simplifyMaskedLoad(IntrinsicInst &II) { + Value *LoadPtr = II.getArgOperand(0); + const Align Alignment = + cast<ConstantInt>(II.getArgOperand(1))->getAlignValue(); + + // If the mask is all ones or undefs, this is a plain vector load of the 1st + // argument. 
+ if (maskIsAllOneOrUndef(II.getArgOperand(2))) { + LoadInst *L = Builder.CreateAlignedLoad(II.getType(), LoadPtr, Alignment, + "unmaskedload"); + L->copyMetadata(II); + return L; + } + + // If we can unconditionally load from this address, replace with a + // load/select idiom. TODO: use DT for context sensitive query + if (isDereferenceablePointer(LoadPtr, II.getType(), + II.getModule()->getDataLayout(), &II, nullptr)) { + LoadInst *LI = Builder.CreateAlignedLoad(II.getType(), LoadPtr, Alignment, + "unmaskedload"); + LI->copyMetadata(II); + return Builder.CreateSelect(II.getArgOperand(2), LI, II.getArgOperand(3)); + } + + return nullptr; +} + +// TODO, Obvious Missing Transforms: +// * Single constant active lane -> store +// * Narrow width by halfs excluding zero/undef lanes +Instruction *InstCombinerImpl::simplifyMaskedStore(IntrinsicInst &II) { + auto *ConstMask = dyn_cast<Constant>(II.getArgOperand(3)); + if (!ConstMask) + return nullptr; + + // If the mask is all zeros, this instruction does nothing. + if (ConstMask->isNullValue()) + return eraseInstFromFunction(II); + + // If the mask is all ones, this is a plain vector store of the 1st argument. 
+ if (ConstMask->isAllOnesValue()) { + Value *StorePtr = II.getArgOperand(1); + Align Alignment = cast<ConstantInt>(II.getArgOperand(2))->getAlignValue(); + StoreInst *S = + new StoreInst(II.getArgOperand(0), StorePtr, false, Alignment); + S->copyMetadata(II); + return S; + } + + if (isa<ScalableVectorType>(ConstMask->getType())) + return nullptr; + + // Use masked off lanes to simplify operands via SimplifyDemandedVectorElts + APInt DemandedElts = possiblyDemandedEltsInMask(ConstMask); + APInt UndefElts(DemandedElts.getBitWidth(), 0); + if (Value *V = + SimplifyDemandedVectorElts(II.getOperand(0), DemandedElts, UndefElts)) + return replaceOperand(II, 0, V); + + return nullptr; +} + +// TODO, Obvious Missing Transforms: +// * Single constant active lane load -> load +// * Dereferenceable address & few lanes -> scalarize speculative load/selects +// * Adjacent vector addresses -> masked.load +// * Narrow width by halfs excluding zero/undef lanes +// * Vector incrementing address -> vector masked load +Instruction *InstCombinerImpl::simplifyMaskedGather(IntrinsicInst &II) { + auto *ConstMask = dyn_cast<Constant>(II.getArgOperand(2)); + if (!ConstMask) + return nullptr; + + // Vector splat address w/known mask -> scalar load + // Fold the gather to load the source vector first lane + // because it is reloading the same value each time + if (ConstMask->isAllOnesValue()) + if (auto *SplatPtr = getSplatValue(II.getArgOperand(0))) { + auto *VecTy = cast<VectorType>(II.getType()); + const Align Alignment = + cast<ConstantInt>(II.getArgOperand(1))->getAlignValue(); + LoadInst *L = Builder.CreateAlignedLoad(VecTy->getElementType(), SplatPtr, + Alignment, "load.scalar"); + Value *Shuf = + Builder.CreateVectorSplat(VecTy->getElementCount(), L, "broadcast"); + return replaceInstUsesWith(II, cast<Instruction>(Shuf)); + } + + return nullptr; +} + +// TODO, Obvious Missing Transforms: +// * Single constant active lane -> store +// * Adjacent vector addresses -> masked.store +// * 
Narrow store width by halfs excluding zero/undef lanes +// * Vector incrementing address -> vector masked store +Instruction *InstCombinerImpl::simplifyMaskedScatter(IntrinsicInst &II) { + auto *ConstMask = dyn_cast<Constant>(II.getArgOperand(3)); + if (!ConstMask) + return nullptr; + + // If the mask is all zeros, a scatter does nothing. + if (ConstMask->isNullValue()) + return eraseInstFromFunction(II); + + // Vector splat address -> scalar store + if (auto *SplatPtr = getSplatValue(II.getArgOperand(1))) { + // scatter(splat(value), splat(ptr), non-zero-mask) -> store value, ptr + if (auto *SplatValue = getSplatValue(II.getArgOperand(0))) { + Align Alignment = cast<ConstantInt>(II.getArgOperand(2))->getAlignValue(); + StoreInst *S = + new StoreInst(SplatValue, SplatPtr, /*IsVolatile=*/false, Alignment); + S->copyMetadata(II); + return S; + } + // scatter(vector, splat(ptr), splat(true)) -> store extract(vector, + // lastlane), ptr + if (ConstMask->isAllOnesValue()) { + Align Alignment = cast<ConstantInt>(II.getArgOperand(2))->getAlignValue(); + VectorType *WideLoadTy = cast<VectorType>(II.getArgOperand(1)->getType()); + ElementCount VF = WideLoadTy->getElementCount(); + Constant *EC = + ConstantInt::get(Builder.getInt32Ty(), VF.getKnownMinValue()); + Value *RunTimeVF = VF.isScalable() ? 
Builder.CreateVScale(EC) : EC; + Value *LastLane = Builder.CreateSub(RunTimeVF, Builder.getInt32(1)); + Value *Extract = + Builder.CreateExtractElement(II.getArgOperand(0), LastLane); + StoreInst *S = + new StoreInst(Extract, SplatPtr, /*IsVolatile=*/false, Alignment); + S->copyMetadata(II); + return S; + } + } + if (isa<ScalableVectorType>(ConstMask->getType())) + return nullptr; + + // Use masked off lanes to simplify operands via SimplifyDemandedVectorElts + APInt DemandedElts = possiblyDemandedEltsInMask(ConstMask); + APInt UndefElts(DemandedElts.getBitWidth(), 0); + if (Value *V = + SimplifyDemandedVectorElts(II.getOperand(0), DemandedElts, UndefElts)) + return replaceOperand(II, 0, V); + if (Value *V = + SimplifyDemandedVectorElts(II.getOperand(1), DemandedElts, UndefElts)) + return replaceOperand(II, 1, V); + + return nullptr; +} + +/// This function transforms launder.invariant.group and strip.invariant.group +/// like: +/// launder(launder(%x)) -> launder(%x) (the result is not the argument) +/// launder(strip(%x)) -> launder(%x) +/// strip(strip(%x)) -> strip(%x) (the result is not the argument) +/// strip(launder(%x)) -> strip(%x) +/// This is legal because it preserves the most recent information about +/// the presence or absence of invariant.group. +static Instruction *simplifyInvariantGroupIntrinsic(IntrinsicInst &II, + InstCombinerImpl &IC) { + auto *Arg = II.getArgOperand(0); + auto *StrippedArg = Arg->stripPointerCasts(); + auto *StrippedInvariantGroupsArg = StrippedArg; + while (auto *Intr = dyn_cast<IntrinsicInst>(StrippedInvariantGroupsArg)) { + if (Intr->getIntrinsicID() != Intrinsic::launder_invariant_group && + Intr->getIntrinsicID() != Intrinsic::strip_invariant_group) + break; + StrippedInvariantGroupsArg = Intr->getArgOperand(0)->stripPointerCasts(); + } + if (StrippedArg == StrippedInvariantGroupsArg) + return nullptr; // No launders/strips to remove. 
+ + Value *Result = nullptr; + + if (II.getIntrinsicID() == Intrinsic::launder_invariant_group) + Result = IC.Builder.CreateLaunderInvariantGroup(StrippedInvariantGroupsArg); + else if (II.getIntrinsicID() == Intrinsic::strip_invariant_group) + Result = IC.Builder.CreateStripInvariantGroup(StrippedInvariantGroupsArg); + else + llvm_unreachable( + "simplifyInvariantGroupIntrinsic only handles launder and strip"); + if (Result->getType()->getPointerAddressSpace() != + II.getType()->getPointerAddressSpace()) + Result = IC.Builder.CreateAddrSpaceCast(Result, II.getType()); + if (Result->getType() != II.getType()) + Result = IC.Builder.CreateBitCast(Result, II.getType()); + + return cast<Instruction>(Result); +} + +static Instruction *foldCttzCtlz(IntrinsicInst &II, InstCombinerImpl &IC) { + assert((II.getIntrinsicID() == Intrinsic::cttz || + II.getIntrinsicID() == Intrinsic::ctlz) && + "Expected cttz or ctlz intrinsic"); + bool IsTZ = II.getIntrinsicID() == Intrinsic::cttz; + Value *Op0 = II.getArgOperand(0); + Value *Op1 = II.getArgOperand(1); + Value *X; + // ctlz(bitreverse(x)) -> cttz(x) + // cttz(bitreverse(x)) -> ctlz(x) + if (match(Op0, m_BitReverse(m_Value(X)))) { + Intrinsic::ID ID = IsTZ ? Intrinsic::ctlz : Intrinsic::cttz; + Function *F = Intrinsic::getDeclaration(II.getModule(), ID, II.getType()); + return CallInst::Create(F, {X, II.getArgOperand(1)}); + } + + if (II.getType()->isIntOrIntVectorTy(1)) { + // ctlz/cttz i1 Op0 --> not Op0 + if (match(Op1, m_Zero())) + return BinaryOperator::CreateNot(Op0); + // If zero is poison, then the input can be assumed to be "true", so the + // instruction simplifies to "false". + assert(match(Op1, m_One()) && "Expected ctlz/cttz operand to be 0 or 1"); + return IC.replaceInstUsesWith(II, ConstantInt::getNullValue(II.getType())); + } + + // If the operand is a select with constant arm(s), try to hoist ctlz/cttz. 
+ if (auto *Sel = dyn_cast<SelectInst>(Op0)) + if (Instruction *R = IC.FoldOpIntoSelect(II, Sel)) + return R; + + if (IsTZ) { + // cttz(-x) -> cttz(x) + if (match(Op0, m_Neg(m_Value(X)))) + return IC.replaceOperand(II, 0, X); + + // cttz(sext(x)) -> cttz(zext(x)) + if (match(Op0, m_OneUse(m_SExt(m_Value(X))))) { + auto *Zext = IC.Builder.CreateZExt(X, II.getType()); + auto *CttzZext = + IC.Builder.CreateBinaryIntrinsic(Intrinsic::cttz, Zext, Op1); + return IC.replaceInstUsesWith(II, CttzZext); + } + + // Zext doesn't change the number of trailing zeros, so narrow: + // cttz(zext(x)) -> zext(cttz(x)) if the 'ZeroIsPoison' parameter is 'true'. + if (match(Op0, m_OneUse(m_ZExt(m_Value(X)))) && match(Op1, m_One())) { + auto *Cttz = IC.Builder.CreateBinaryIntrinsic(Intrinsic::cttz, X, + IC.Builder.getTrue()); + auto *ZextCttz = IC.Builder.CreateZExt(Cttz, II.getType()); + return IC.replaceInstUsesWith(II, ZextCttz); + } + + // cttz(abs(x)) -> cttz(x) + // cttz(nabs(x)) -> cttz(x) + Value *Y; + SelectPatternFlavor SPF = matchSelectPattern(Op0, X, Y).Flavor; + if (SPF == SPF_ABS || SPF == SPF_NABS) + return IC.replaceOperand(II, 0, X); + + if (match(Op0, m_Intrinsic<Intrinsic::abs>(m_Value(X)))) + return IC.replaceOperand(II, 0, X); + } + + KnownBits Known = IC.computeKnownBits(Op0, 0, &II); + + // Create a mask for bits above (ctlz) or below (cttz) the first known one. + unsigned PossibleZeros = IsTZ ? Known.countMaxTrailingZeros() + : Known.countMaxLeadingZeros(); + unsigned DefiniteZeros = IsTZ ? Known.countMinTrailingZeros() + : Known.countMinLeadingZeros(); + + // If all bits above (ctlz) or below (cttz) the first known one are known + // zero, this value is constant. + // FIXME: This should be in InstSimplify because we're replacing an + // instruction with a constant. 
  if (PossibleZeros == DefiniteZeros) {
    auto *C = ConstantInt::get(Op0->getType(), DefiniteZeros);
    return IC.replaceInstUsesWith(II, C);
  }

  // If the input to cttz/ctlz is known to be non-zero,
  // then change the 'ZeroIsPoison' parameter to 'true'
  // because we know the zero behavior can't affect the result.
  if (!Known.One.isZero() ||
      isKnownNonZero(Op0, IC.getDataLayout(), 0, &IC.getAssumptionCache(), &II,
                     &IC.getDominatorTree())) {
    if (!match(II.getArgOperand(1), m_One()))
      return IC.replaceOperand(II, 1, IC.Builder.getTrue());
  }

  // Add range metadata since known bits can't completely reflect what we know.
  // TODO: Handle splat vectors.
  auto *IT = dyn_cast<IntegerType>(Op0->getType());
  if (IT && IT->getBitWidth() != 1 && !II.getMetadata(LLVMContext::MD_range)) {
    Metadata *LowAndHigh[] = {
        ConstantAsMetadata::get(ConstantInt::get(IT, DefiniteZeros)),
        ConstantAsMetadata::get(ConstantInt::get(IT, PossibleZeros + 1))};
    II.setMetadata(LLVMContext::MD_range,
                   MDNode::get(II.getContext(), LowAndHigh));
    return &II;
  }

  return nullptr;
}

/// Try to simplify a call to the ctpop intrinsic.
///
/// The checks run in this order: strip a bitreverse/bswap/rotate from the
/// operand, rewrite the `x | -x` and `~x & (x - 1)` operand shapes in terms of
/// cttz, narrow through a one-use zext, hoist into a select operand, turn a
/// "single possibly-set bit" operand into a shift, and finally attach !range
/// metadata to scalar integer results.
///
/// \param II the ctpop call (asserted).
/// \param IC the combiner whose Builder and helper methods are used.
/// \returns an instruction for the caller to act on (a newly created
/// replacement, whatever IC.replaceOperand / IC.replaceInstUsesWith /
/// IC.FoldOpIntoSelect hand back, or &II after metadata was attached), or
/// nullptr when nothing applied.
static Instruction *foldCtpop(IntrinsicInst &II, InstCombinerImpl &IC) {
  assert(II.getIntrinsicID() == Intrinsic::ctpop &&
         "Expected ctpop intrinsic");
  Type *Ty = II.getType();
  unsigned BitWidth = Ty->getScalarSizeInBits();
  Value *Op0 = II.getArgOperand(0);
  Value *X, *Y;

  // ctpop(bitreverse(x)) -> ctpop(x)
  // ctpop(bswap(x)) -> ctpop(x)
  if (match(Op0, m_BitReverse(m_Value(X))) || match(Op0, m_BSwap(m_Value(X))))
    return IC.replaceOperand(II, 0, X);

  // ctpop(rot(x)) -> ctpop(x)
  // A funnel shift only counts as a rotate here when both of its data
  // operands are the same value (X == Y); the shift amount is ignored.
  if ((match(Op0, m_FShl(m_Value(X), m_Value(Y), m_Value())) ||
       match(Op0, m_FShr(m_Value(X), m_Value(Y), m_Value()))) &&
      X == Y)
    return IC.replaceOperand(II, 0, X);

  // ctpop(x | -x) -> bitwidth - cttz(x, false)
  // This creates two new instructions (cttz and sub), so the 'or' operand
  // is required to have a single use.
  if (Op0->hasOneUse() &&
      match(Op0, m_c_Or(m_Value(X), m_Neg(m_Deferred(X))))) {
    Function *F =
        Intrinsic::getDeclaration(II.getModule(), Intrinsic::cttz, Ty);
    auto *Cttz = IC.Builder.CreateCall(F, {X, IC.Builder.getFalse()});
    auto *Bw = ConstantInt::get(Ty, APInt(BitWidth, BitWidth));
    return IC.replaceInstUsesWith(II, IC.Builder.CreateSub(Bw, Cttz));
  }

  // ctpop(~x & (x - 1)) -> cttz(x, false)
  // (x - 1) is matched in its `add x, -1` form.
  if (match(Op0,
            m_c_And(m_Not(m_Value(X)), m_Add(m_Deferred(X), m_AllOnes())))) {
    Function *F =
        Intrinsic::getDeclaration(II.getModule(), Intrinsic::cttz, Ty);
    return CallInst::Create(F, {X, IC.Builder.getFalse()});
  }

  // Zext doesn't change the number of set bits, so narrow:
  // ctpop (zext X) --> zext (ctpop X)
  if (match(Op0, m_OneUse(m_ZExt(m_Value(X))))) {
    Value *NarrowPop = IC.Builder.CreateUnaryIntrinsic(Intrinsic::ctpop, X);
    return CastInst::Create(Instruction::ZExt, NarrowPop, Ty);
  }

  // If the operand is a select with constant arm(s), try to hoist ctpop.
  if (auto *Sel = dyn_cast<SelectInst>(Op0))
    if (Instruction *R = IC.FoldOpIntoSelect(II, Sel))
      return R;

  KnownBits Known(BitWidth);
  IC.computeKnownBits(Op0, Known, 0, &II);

  // If all bits are zero except for exactly one fixed bit, then the result
  // must be 0 or 1, and we can get that answer by shifting to LSB:
  // ctpop (X & 32) --> (X & 32) >> 5
  // ~Known.Zero is the mask of bits not known to be zero; the shift amount is
  // the index of its single set bit.
  if ((~Known.Zero).isPowerOf2())
    return BinaryOperator::CreateLShr(
        Op0, ConstantInt::get(Ty, (~Known.Zero).exactLogBase2()));

  // FIXME: Try to simplify vectors of integers.
  auto *IT = dyn_cast<IntegerType>(Ty);
  if (!IT)
    return nullptr;

  // Add range metadata since known bits can't completely reflect what we know.
  // Skipped for 1-bit results and when the call already carries !range.
  // NOTE(review): the "+ 1" on the high value suggests the !range upper bound
  // is exclusive -- confirm against the LangRef.
  unsigned MinCount = Known.countMinPopulation();
  unsigned MaxCount = Known.countMaxPopulation();
  if (IT->getBitWidth() != 1 && !II.getMetadata(LLVMContext::MD_range)) {
    Metadata *LowAndHigh[] = {
        ConstantAsMetadata::get(ConstantInt::get(IT, MinCount)),
        ConstantAsMetadata::get(ConstantInt::get(IT, MaxCount + 1))};
    II.setMetadata(LLVMContext::MD_range,
                   MDNode::get(II.getContext(), LowAndHigh));
    return &II;
  }

  return nullptr;
}

/// Convert a table lookup to shufflevector if the mask is constant.
/// This could benefit tbl1 if the mask is { 7,6,5,4,3,2,1,0 }, in
/// which case we could lower the shufflevector with rev64 instructions
/// as it's actually a byte reverse.
///
/// \param II the intrinsic call; operand 0 is used as the shuffle source and
/// operand 1 as the index mask.
/// \param Builder used to create the shufflevector.
/// \returns the value produced by Builder.CreateShuffleVector, or nullptr if
/// the mask is not a constant, the result type is not <8 x i8>, or any mask
/// element is not a ConstantInt in the range [0, 8).
static Value *simplifyNeonTbl1(const IntrinsicInst &II,
                               InstCombiner::BuilderTy &Builder) {
  // Bail out if the mask is not a constant.
  auto *C = dyn_cast<Constant>(II.getArgOperand(1));
  if (!C)
    return nullptr;

  auto *VecTy = cast<FixedVectorType>(II.getType());
  unsigned NumElts = VecTy->getNumElements();

  // Only perform this transformation for <8 x i8> vector types.
  if (!VecTy->getElementType()->isIntegerTy(8) || NumElts != 8)
    return nullptr;

  // Sized for the <8 x i8> case enforced above.
  int Indexes[8];

  for (unsigned I = 0; I < NumElts; ++I) {
    Constant *COp = C->getAggregateElement(I);

    // Every lane must be a plain integer constant.
    if (!COp || !isa<ConstantInt>(COp))
      return nullptr;

    Indexes[I] = cast<ConstantInt>(COp)->getLimitedValue();

    // Make sure the mask indices are in range.
    if ((unsigned)Indexes[I] >= NumElts)
      return nullptr;
  }

  // The second shuffle operand is an all-zero vector of the source type.
  auto *V1 = II.getArgOperand(0);
  auto *V2 = Constant::getNullValue(V1->getType());
  return Builder.CreateShuffleVector(V1, V2, makeArrayRef(Indexes));
}

// Returns true iff the 2 intrinsics have the same operands, limiting the
// comparison to the first NumOperands.
+static bool haveSameOperands(const IntrinsicInst &I, const IntrinsicInst &E, + unsigned NumOperands) { + assert(I.arg_size() >= NumOperands && "Not enough operands"); + assert(E.arg_size() >= NumOperands && "Not enough operands"); + for (unsigned i = 0; i < NumOperands; i++) + if (I.getArgOperand(i) != E.getArgOperand(i)) + return false; + return true; +} + +// Remove trivially empty start/end intrinsic ranges, i.e. a start +// immediately followed by an end (ignoring debuginfo or other +// start/end intrinsics in between). As this handles only the most trivial +// cases, tracking the nesting level is not needed: +// +// call @llvm.foo.start(i1 0) +// call @llvm.foo.start(i1 0) ; This one won't be skipped: it will be removed +// call @llvm.foo.end(i1 0) +// call @llvm.foo.end(i1 0) ; &I +static bool +removeTriviallyEmptyRange(IntrinsicInst &EndI, InstCombinerImpl &IC, + std::function<bool(const IntrinsicInst &)> IsStart) { + // We start from the end intrinsic and scan backwards, so that InstCombine + // has already processed (and potentially removed) all the instructions + // before the end intrinsic. + BasicBlock::reverse_iterator BI(EndI), BE(EndI.getParent()->rend()); + for (; BI != BE; ++BI) { + if (auto *I = dyn_cast<IntrinsicInst>(&*BI)) { + if (I->isDebugOrPseudoInst() || + I->getIntrinsicID() == EndI.getIntrinsicID()) + continue; + if (IsStart(*I)) { + if (haveSameOperands(EndI, *I, EndI.arg_size())) { + IC.eraseInstFromFunction(*I); + IC.eraseInstFromFunction(EndI); + return true; + } + // Skip start intrinsics that don't pair with this end intrinsic. 
+ continue; + } + } + break; + } + + return false; +} + +Instruction *InstCombinerImpl::visitVAEndInst(VAEndInst &I) { + removeTriviallyEmptyRange(I, *this, [](const IntrinsicInst &I) { + return I.getIntrinsicID() == Intrinsic::vastart || + I.getIntrinsicID() == Intrinsic::vacopy; + }); + return nullptr; +} + +static CallInst *canonicalizeConstantArg0ToArg1(CallInst &Call) { + assert(Call.arg_size() > 1 && "Need at least 2 args to swap"); + Value *Arg0 = Call.getArgOperand(0), *Arg1 = Call.getArgOperand(1); + if (isa<Constant>(Arg0) && !isa<Constant>(Arg1)) { + Call.setArgOperand(0, Arg1); + Call.setArgOperand(1, Arg0); + return &Call; + } + return nullptr; +} + +/// Creates a result tuple for an overflow intrinsic \p II with a given +/// \p Result and a constant \p Overflow value. +static Instruction *createOverflowTuple(IntrinsicInst *II, Value *Result, + Constant *Overflow) { + Constant *V[] = {PoisonValue::get(Result->getType()), Overflow}; + StructType *ST = cast<StructType>(II->getType()); + Constant *Struct = ConstantStruct::get(ST, V); + return InsertValueInst::Create(Struct, Result, 0); +} + +Instruction * +InstCombinerImpl::foldIntrinsicWithOverflowCommon(IntrinsicInst *II) { + WithOverflowInst *WO = cast<WithOverflowInst>(II); + Value *OperationResult = nullptr; + Constant *OverflowResult = nullptr; + if (OptimizeOverflowCheck(WO->getBinaryOp(), WO->isSigned(), WO->getLHS(), + WO->getRHS(), *WO, OperationResult, OverflowResult)) + return createOverflowTuple(WO, OperationResult, OverflowResult); + return nullptr; +} + +static Optional<bool> getKnownSign(Value *Op, Instruction *CxtI, + const DataLayout &DL, AssumptionCache *AC, + DominatorTree *DT) { + KnownBits Known = computeKnownBits(Op, DL, 0, AC, CxtI, DT); + if (Known.isNonNegative()) + return false; + if (Known.isNegative()) + return true; + + Value *X, *Y; + if (match(Op, m_NSWSub(m_Value(X), m_Value(Y)))) + return isImpliedByDomCondition(ICmpInst::ICMP_SLT, X, Y, CxtI, DL); + + return 
isImpliedByDomCondition( + ICmpInst::ICMP_SLT, Op, Constant::getNullValue(Op->getType()), CxtI, DL); +} + +/// Try to canonicalize min/max(X + C0, C1) as min/max(X, C1 - C0) + C0. This +/// can trigger other combines. +static Instruction *moveAddAfterMinMax(IntrinsicInst *II, + InstCombiner::BuilderTy &Builder) { + Intrinsic::ID MinMaxID = II->getIntrinsicID(); + assert((MinMaxID == Intrinsic::smax || MinMaxID == Intrinsic::smin || + MinMaxID == Intrinsic::umax || MinMaxID == Intrinsic::umin) && + "Expected a min or max intrinsic"); + + // TODO: Match vectors with undef elements, but undef may not propagate. + Value *Op0 = II->getArgOperand(0), *Op1 = II->getArgOperand(1); + Value *X; + const APInt *C0, *C1; + if (!match(Op0, m_OneUse(m_Add(m_Value(X), m_APInt(C0)))) || + !match(Op1, m_APInt(C1))) + return nullptr; + + // Check for necessary no-wrap and overflow constraints. + bool IsSigned = MinMaxID == Intrinsic::smax || MinMaxID == Intrinsic::smin; + auto *Add = cast<BinaryOperator>(Op0); + if ((IsSigned && !Add->hasNoSignedWrap()) || + (!IsSigned && !Add->hasNoUnsignedWrap())) + return nullptr; + + // If the constant difference overflows, then instsimplify should reduce the + // min/max to the add or C1. + bool Overflow; + APInt CDiff = + IsSigned ? C1->ssub_ov(*C0, Overflow) : C1->usub_ov(*C0, Overflow); + assert(!Overflow && "Expected simplify of min/max"); + + // min/max (add X, C0), C1 --> add (min/max X, C1 - C0), C0 + // Note: the "mismatched" no-overflow setting does not propagate. + Constant *NewMinMaxC = ConstantInt::get(II->getType(), CDiff); + Value *NewMinMax = Builder.CreateBinaryIntrinsic(MinMaxID, X, NewMinMaxC); + return IsSigned ? BinaryOperator::CreateNSWAdd(NewMinMax, Add->getOperand(1)) + : BinaryOperator::CreateNUWAdd(NewMinMax, Add->getOperand(1)); +} +/// Match a sadd_sat or ssub_sat which is using min/max to clamp the value. 
/// \returns a sext (to the original type) of a newly built narrow
/// sadd_sat/ssub_sat call, or nullptr if any of the checks below fails.
Instruction *InstCombinerImpl::matchSAddSubSat(IntrinsicInst &MinMax1) {
  Type *Ty = MinMax1.getType();

  // We are looking for a tree of:
  // max(INT_MIN, min(INT_MAX, add(sext(A), sext(B))))
  // Where the min and max could be reversed
  Instruction *MinMax2;
  BinaryOperator *AddSub;
  const APInt *MinValue, *MaxValue;
  if (match(&MinMax1, m_SMin(m_Instruction(MinMax2), m_APInt(MaxValue)))) {
    if (!match(MinMax2, m_SMax(m_BinOp(AddSub), m_APInt(MinValue))))
      return nullptr;
  } else if (match(&MinMax1,
                   m_SMax(m_Instruction(MinMax2), m_APInt(MinValue)))) {
    if (!match(MinMax2, m_SMin(m_BinOp(AddSub), m_APInt(MaxValue))))
      return nullptr;
  } else
    return nullptr;

  // Check that the constants clamp a saturate, and that the new type would be
  // sensible to convert to.
  // I.e. MaxValue + 1 is a power of two and MinValue == -(MaxValue + 1).
  if (!(*MaxValue + 1).isPowerOf2() || -*MinValue != *MaxValue + 1)
    return nullptr;
  // In what bitwidth can this be treated as saturating arithmetics?
  // Example: MaxValue == 127 gives log2(128) + 1 == 8 bits.
  unsigned NewBitWidth = (*MaxValue + 1).logBase2() + 1;
  // FIXME: This isn't quite right for vectors, but using the scalar type is a
  // good first approximation for what should be done there.
  if (!shouldChangeType(Ty->getScalarType()->getIntegerBitWidth(), NewBitWidth))
    return nullptr;

  // Also make sure that the inner min/max and the add/sub have one use.
  if (!MinMax2->hasOneUse() || !AddSub->hasOneUse())
    return nullptr;

  // Create the new type (which can be a vector type)
  Type *NewTy = Ty->getWithNewBitWidth(NewBitWidth);

  // m_BinOp accepts any binary operator; only add and sub have a saturating
  // counterpart here.
  Intrinsic::ID IntrinsicID;
  if (AddSub->getOpcode() == Instruction::Add)
    IntrinsicID = Intrinsic::sadd_sat;
  else if (AddSub->getOpcode() == Instruction::Sub)
    IntrinsicID = Intrinsic::ssub_sat;
  else
    return nullptr;

  // The two operands of the add/sub must be nsw-truncatable to the NewTy. This
  // is usually achieved via a sext from a smaller type.
  if (ComputeMaxSignificantBits(AddSub->getOperand(0), 0, AddSub) >
          NewBitWidth ||
      ComputeMaxSignificantBits(AddSub->getOperand(1), 0, AddSub) > NewBitWidth)
    return nullptr;

  // Finally create and return the sat intrinsic, truncated to the new type
  Function *F = Intrinsic::getDeclaration(MinMax1.getModule(), IntrinsicID, NewTy);
  Value *AT = Builder.CreateTrunc(AddSub->getOperand(0), NewTy);
  Value *BT = Builder.CreateTrunc(AddSub->getOperand(1), NewTy);
  Value *Sat = Builder.CreateCall(F, {AT, BT});
  return CastInst::Create(Instruction::SExt, Sat, Ty);
}


/// If we have a clamp pattern like max (min X, 42), 41 -- where the output
/// can only be one of two possible constant values -- turn that into a select
/// of constants.
///
/// \returns the new select (its compare is created through \p Builder), or
/// nullptr if the pattern does not match.
static Instruction *foldClampRangeOfTwo(IntrinsicInst *II,
                                        InstCombiner::BuilderTy &Builder) {
  Value *I0 = II->getArgOperand(0), *I1 = II->getArgOperand(1);
  Value *X;
  const APInt *C0, *C1;
  // The outer limit must be a constant and the inner operand single-use.
  if (!match(I1, m_APInt(C1)) || !I0->hasOneUse())
    return nullptr;

  // For each outer intrinsic, the inner operand must be the opposite-direction
  // min/max of the same signedness, with a constant exactly one away from C1.
  // Pred stays BAD_ICMP_PREDICATE when that is not the case.
  CmpInst::Predicate Pred = CmpInst::BAD_ICMP_PREDICATE;
  switch (II->getIntrinsicID()) {
  case Intrinsic::smax:
    if (match(I0, m_SMin(m_Value(X), m_APInt(C0))) && *C0 == *C1 + 1)
      Pred = ICmpInst::ICMP_SGT;
    break;
  case Intrinsic::smin:
    if (match(I0, m_SMax(m_Value(X), m_APInt(C0))) && *C1 == *C0 + 1)
      Pred = ICmpInst::ICMP_SLT;
    break;
  case Intrinsic::umax:
    if (match(I0, m_UMin(m_Value(X), m_APInt(C0))) && *C0 == *C1 + 1)
      Pred = ICmpInst::ICMP_UGT;
    break;
  case Intrinsic::umin:
    if (match(I0, m_UMax(m_Value(X), m_APInt(C0))) && *C1 == *C0 + 1)
      Pred = ICmpInst::ICMP_ULT;
    break;
  default:
    llvm_unreachable("Expected min/max intrinsic");
  }
  if (Pred == CmpInst::BAD_ICMP_PREDICATE)
    return nullptr;

  // max (min X, 42), 41 --> X > 41 ? 42 : 41
  // min (max X, 42), 43 --> X < 43 ? 42 : 43
  Value *Cmp = Builder.CreateICmp(Pred, X, I1);
  return SelectInst::Create(Cmp, ConstantInt::get(II->getType(), *C0), I1);
}

/// If this min/max has a constant operand and an operand that is a matching
/// min/max with a constant operand, constant-fold the 2 constant operands.
///
/// \returns a newly created call to the same min/max intrinsic, or nullptr if
/// the pattern does not match.
static Instruction *reassociateMinMaxWithConstants(IntrinsicInst *II) {
  Intrinsic::ID MinMaxID = II->getIntrinsicID();
  auto *LHS = dyn_cast<IntrinsicInst>(II->getArgOperand(0));
  if (!LHS || LHS->getIntrinsicID() != MinMaxID)
    return nullptr;

  Constant *C0, *C1;
  if (!match(LHS->getArgOperand(1), m_ImmConstant(C0)) ||
      !match(II->getArgOperand(1), m_ImmConstant(C1)))
    return nullptr;

  // max (max X, C0), C1 --> max X, (max C0, C1) --> max X, NewC
  // NewC is built as the constant expression `(C0 pred C1) ? C0 : C1`, using
  // the predicate MinMaxIntrinsic::getPredicate reports for this intrinsic.
  ICmpInst::Predicate Pred = MinMaxIntrinsic::getPredicate(MinMaxID);
  Constant *CondC = ConstantExpr::getICmp(Pred, C0, C1);
  Constant *NewC = ConstantExpr::getSelect(CondC, C0, C1);

  Module *Mod = II->getModule();
  Function *MinMax = Intrinsic::getDeclaration(Mod, MinMaxID, II->getType());
  return CallInst::Create(MinMax, {LHS->getArgOperand(0), NewC});
}

/// If this min/max has a matching min/max operand with a constant, try to push
/// the constant operand into this instruction. This can enable more folds.
///
/// \returns a newly created outer call whose first argument is a new inner
/// min/max created through \p Builder, or nullptr if the pattern does not
/// match.
static Instruction *
reassociateMinMaxWithConstantInOperand(IntrinsicInst *II,
                                       InstCombiner::BuilderTy &Builder) {
  // Match and capture a min/max operand candidate.
  Value *X, *Y;
  Constant *C;
  Instruction *Inner;
  if (!match(II, m_c_MaxOrMin(m_OneUse(m_CombineAnd(
                                  m_Instruction(Inner),
                                  m_MaxOrMin(m_Value(X), m_ImmConstant(C)))),
                              m_Value(Y))))
    return nullptr;

  // The inner op must match. Check for constants to avoid infinite loops.
  Intrinsic::ID MinMaxID = II->getIntrinsicID();
  auto *InnerMM = dyn_cast<IntrinsicInst>(Inner);
  if (!InnerMM || InnerMM->getIntrinsicID() != MinMaxID ||
      match(X, m_ImmConstant()) || match(Y, m_ImmConstant()))
    return nullptr;

  // max (max X, C), Y --> max (max X, Y), C
  Function *MinMax =
      Intrinsic::getDeclaration(II->getModule(), MinMaxID, II->getType());
  Value *NewInner = Builder.CreateBinaryIntrinsic(MinMaxID, X, Y);
  // The replacement inner op takes over the old inner op's name.
  NewInner->takeName(Inner);
  return CallInst::Create(MinMax, {NewInner, C});
}

/// Reduce a sequence of min/max intrinsics with a common operand.
///
/// \returns a newly created call that reuses one of the two inner min/max
/// calls together with the remaining non-shared operand, or nullptr if the
/// operands are not two calls of the same intrinsic, if neither of them is
/// single-use, or if they share no operand.
static Instruction *factorizeMinMaxTree(IntrinsicInst *II) {
  // Match 3 of the same min/max ops. Example: umin(umin(), umin()).
  auto *LHS = dyn_cast<IntrinsicInst>(II->getArgOperand(0));
  auto *RHS = dyn_cast<IntrinsicInst>(II->getArgOperand(1));
  Intrinsic::ID MinMaxID = II->getIntrinsicID();
  if (!LHS || !RHS || LHS->getIntrinsicID() != MinMaxID ||
      RHS->getIntrinsicID() != MinMaxID ||
      (!LHS->hasOneUse() && !RHS->hasOneUse()))
    return nullptr;

  Value *A = LHS->getArgOperand(0);
  Value *B = LHS->getArgOperand(1);
  Value *C = RHS->getArgOperand(0);
  Value *D = RHS->getArgOperand(1);

  // Look for a common operand.
  Value *MinMaxOp = nullptr;
  Value *ThirdOp = nullptr;
  if (LHS->hasOneUse()) {
    // If the LHS is only used in this chain and the RHS is used outside of it,
    // reuse the RHS min/max because that will eliminate the LHS.
    if (D == A || C == A) {
      // min(min(a, b), min(c, a)) --> min(min(c, a), b)
      // min(min(a, b), min(a, d)) --> min(min(a, d), b)
      MinMaxOp = RHS;
      ThirdOp = B;
    } else if (D == B || C == B) {
      // min(min(a, b), min(c, b)) --> min(min(c, b), a)
      // min(min(a, b), min(b, d)) --> min(min(b, d), a)
      MinMaxOp = RHS;
      ThirdOp = A;
    }
  } else {
    assert(RHS->hasOneUse() && "Expected one-use operand");
    // Reuse the LHS. This will eliminate the RHS.
    if (D == A || D == B) {
      // min(min(a, b), min(c, a)) --> min(min(a, b), c)
      // min(min(a, b), min(c, b)) --> min(min(a, b), c)
      MinMaxOp = LHS;
      ThirdOp = C;
    } else if (C == A || C == B) {
      // min(min(a, b), min(a, d)) --> min(min(a, b), d)
      // min(min(a, b), min(b, d)) --> min(min(a, b), d)
      MinMaxOp = LHS;
      ThirdOp = D;
    }
  }

  if (!MinMaxOp || !ThirdOp)
    return nullptr;

  Module *Mod = II->getModule();
  Function *MinMax = Intrinsic::getDeclaration(Mod, MinMaxID, II->getType());
  return CallInst::Create(MinMax, { MinMaxOp, ThirdOp });
}

/// If all arguments of the intrinsic are unary shuffles with the same mask,
/// try to shuffle after the intrinsic.
///
/// \returns a new shufflevector of the intrinsic applied to the unshuffled
/// sources (that call is created through \p Builder), or nullptr if the
/// intrinsic is not one of those listed below or the operands do not match.
static Instruction *
foldShuffledIntrinsicOperands(IntrinsicInst *II,
                              InstCombiner::BuilderTy &Builder) {
  // TODO: This should be extended to handle other intrinsics like ctpop, etc.
  //       (fshl/fshr are already accepted below.) Use
  //       llvm::isTriviallyVectorizable() and related to determine which
  //       intrinsics are safe to shuffle?
  switch (II->getIntrinsicID()) {
  case Intrinsic::smax:
  case Intrinsic::smin:
  case Intrinsic::umax:
  case Intrinsic::umin:
  case Intrinsic::fma:
  case Intrinsic::fshl:
  case Intrinsic::fshr:
    break;
  default:
    return nullptr;
  }

  // The first argument fixes the mask that all other arguments must share.
  Value *X;
  ArrayRef<int> Mask;
  if (!match(II->getArgOperand(0),
             m_Shuffle(m_Value(X), m_Undef(), m_Mask(Mask))))
    return nullptr;

  // At least 1 operand must have 1 use because we are creating 2 instructions.
  if (none_of(II->args(), [](Value *V) { return V->hasOneUse(); }))
    return nullptr;

  // See if all arguments are shuffled with the same mask.
  // The shuffle sources must also all have the same type as the first one.
  SmallVector<Value *, 4> NewArgs(II->arg_size());
  NewArgs[0] = X;
  Type *SrcTy = X->getType();
  for (unsigned i = 1, e = II->arg_size(); i != e; ++i) {
    if (!match(II->getArgOperand(i),
               m_Shuffle(m_Value(X), m_Undef(), m_SpecificMask(Mask))) ||
        X->getType() != SrcTy)
      return nullptr;
    NewArgs[i] = X;
  }

  // intrinsic (shuf X, M), (shuf Y, M), ... --> shuf (intrinsic X, Y, ...), M
  // II is handed to CreateIntrinsic only when it is an FPMathOperator --
  // presumably so its fast-math flags carry over to the new call; confirm
  // against IRBuilder::CreateIntrinsic.
  Instruction *FPI = isa<FPMathOperator>(II) ? II : nullptr;
  Value *NewIntrinsic =
      Builder.CreateIntrinsic(II->getIntrinsicID(), SrcTy, NewArgs, FPI);
  return new ShuffleVectorInst(NewIntrinsic, Mask);
}

/// CallInst simplification. This mostly only handles folding of intrinsic
/// instructions. For normal calls, it allows visitCallBase to do the heavy
/// lifting.
Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
  // Don't try to simplify calls without uses. It will not do anything useful,
  // but will result in the following folds being skipped.
  if (!CI.use_empty())
    if (Value *V = simplifyCall(&CI, SQ.getWithInstruction(&CI)))
      return replaceInstUsesWith(CI, V);

  if (isFreeCall(&CI, &TLI))
    return visitFree(CI);

  // If the caller function (i.e. us, the function that contains this CallInst)
  // is nounwind, mark the call as nounwind, even if the callee isn't.
  if (CI.getFunction()->doesNotThrow() && !CI.doesNotThrow()) {
    CI.setDoesNotThrow();
    return &CI;
  }

  IntrinsicInst *II = dyn_cast<IntrinsicInst>(&CI);
  if (!II) return visitCallBase(CI);

  // For atomic unordered mem intrinsics if len is not a positive or
  // not a multiple of element size then behavior is undefined.
  if (auto *AMI = dyn_cast<AtomicMemIntrinsic>(II))
    if (ConstantInt *NumBytes = dyn_cast<ConstantInt>(AMI->getLength()))
      if (NumBytes->getSExtValue() < 0 ||
          (NumBytes->getZExtValue() % AMI->getElementSizeInBytes() != 0)) {
        CreateNonTerminatorUnreachable(AMI);
        assert(AMI->getType()->isVoidTy() &&
               "non void atomic unordered mem intrinsic");
        return eraseInstFromFunction(*AMI);
      }

  // Intrinsics cannot occur in an invoke or a callbr, so handle them here
  // instead of in visitCallBase.
  if (auto *MI = dyn_cast<AnyMemIntrinsic>(II)) {
    bool Changed = false;

    // memmove/cpy/set of zero bytes is a noop.
+ if (Constant *NumBytes = dyn_cast<Constant>(MI->getLength())) { + if (NumBytes->isNullValue()) + return eraseInstFromFunction(CI); + } + + // No other transformations apply to volatile transfers. + if (auto *M = dyn_cast<MemIntrinsic>(MI)) + if (M->isVolatile()) + return nullptr; + + // If we have a memmove and the source operation is a constant global, + // then the source and dest pointers can't alias, so we can change this + // into a call to memcpy. + if (auto *MMI = dyn_cast<AnyMemMoveInst>(MI)) { + if (GlobalVariable *GVSrc = dyn_cast<GlobalVariable>(MMI->getSource())) + if (GVSrc->isConstant()) { + Module *M = CI.getModule(); + Intrinsic::ID MemCpyID = + isa<AtomicMemMoveInst>(MMI) + ? Intrinsic::memcpy_element_unordered_atomic + : Intrinsic::memcpy; + Type *Tys[3] = { CI.getArgOperand(0)->getType(), + CI.getArgOperand(1)->getType(), + CI.getArgOperand(2)->getType() }; + CI.setCalledFunction(Intrinsic::getDeclaration(M, MemCpyID, Tys)); + Changed = true; + } + } + + if (AnyMemTransferInst *MTI = dyn_cast<AnyMemTransferInst>(MI)) { + // memmove(x,x,size) -> noop. + if (MTI->getSource() == MTI->getDest()) + return eraseInstFromFunction(CI); + } + + // If we can determine a pointer alignment that is bigger than currently + // set, update the alignment. + if (auto *MTI = dyn_cast<AnyMemTransferInst>(MI)) { + if (Instruction *I = SimplifyAnyMemTransfer(MTI)) + return I; + } else if (auto *MSI = dyn_cast<AnyMemSetInst>(MI)) { + if (Instruction *I = SimplifyAnyMemSet(MSI)) + return I; + } + + if (Changed) return II; + } + + // For fixed width vector result intrinsics, use the generic demanded vector + // support. 
+ if (auto *IIFVTy = dyn_cast<FixedVectorType>(II->getType())) { + auto VWidth = IIFVTy->getNumElements(); + APInt UndefElts(VWidth, 0); + APInt AllOnesEltMask(APInt::getAllOnes(VWidth)); + if (Value *V = SimplifyDemandedVectorElts(II, AllOnesEltMask, UndefElts)) { + if (V != II) + return replaceInstUsesWith(*II, V); + return II; + } + } + + if (II->isCommutative()) { + if (CallInst *NewCall = canonicalizeConstantArg0ToArg1(CI)) + return NewCall; + } + + // Unused constrained FP intrinsic calls may have declared side effect, which + // prevents it from being removed. In some cases however the side effect is + // actually absent. To detect this case, call SimplifyConstrainedFPCall. If it + // returns a replacement, the call may be removed. + if (CI.use_empty() && isa<ConstrainedFPIntrinsic>(CI)) { + if (simplifyConstrainedFPCall(&CI, SQ.getWithInstruction(&CI))) + return eraseInstFromFunction(CI); + } + + Intrinsic::ID IID = II->getIntrinsicID(); + switch (IID) { + case Intrinsic::objectsize: + if (Value *V = lowerObjectSizeCall(II, DL, &TLI, AA, /*MustSucceed=*/false)) + return replaceInstUsesWith(CI, V); + return nullptr; + case Intrinsic::abs: { + Value *IIOperand = II->getArgOperand(0); + bool IntMinIsPoison = cast<Constant>(II->getArgOperand(1))->isOneValue(); + + // abs(-x) -> abs(x) + // TODO: Copy nsw if it was present on the neg? 
+ Value *X; + if (match(IIOperand, m_Neg(m_Value(X)))) + return replaceOperand(*II, 0, X); + if (match(IIOperand, m_Select(m_Value(), m_Value(X), m_Neg(m_Deferred(X))))) + return replaceOperand(*II, 0, X); + if (match(IIOperand, m_Select(m_Value(), m_Neg(m_Value(X)), m_Deferred(X)))) + return replaceOperand(*II, 0, X); + + if (Optional<bool> Sign = getKnownSign(IIOperand, II, DL, &AC, &DT)) { + // abs(x) -> x if x >= 0 + if (!*Sign) + return replaceInstUsesWith(*II, IIOperand); + + // abs(x) -> -x if x < 0 + if (IntMinIsPoison) + return BinaryOperator::CreateNSWNeg(IIOperand); + return BinaryOperator::CreateNeg(IIOperand); + } + + // abs (sext X) --> zext (abs X*) + // Clear the IsIntMin (nsw) bit on the abs to allow narrowing. + if (match(IIOperand, m_OneUse(m_SExt(m_Value(X))))) { + Value *NarrowAbs = + Builder.CreateBinaryIntrinsic(Intrinsic::abs, X, Builder.getFalse()); + return CastInst::Create(Instruction::ZExt, NarrowAbs, II->getType()); + } + + // Match a complicated way to check if a number is odd/even: + // abs (srem X, 2) --> and X, 1 + const APInt *C; + if (match(IIOperand, m_SRem(m_Value(X), m_APInt(C))) && *C == 2) + return BinaryOperator::CreateAnd(X, ConstantInt::get(II->getType(), 1)); + + break; + } + case Intrinsic::umin: { + Value *I0 = II->getArgOperand(0), *I1 = II->getArgOperand(1); + // umin(x, 1) == zext(x != 0) + if (match(I1, m_One())) { + Value *Zero = Constant::getNullValue(I0->getType()); + Value *Cmp = Builder.CreateICmpNE(I0, Zero); + return CastInst::Create(Instruction::ZExt, Cmp, II->getType()); + } + LLVM_FALLTHROUGH; + } + case Intrinsic::umax: { + Value *I0 = II->getArgOperand(0), *I1 = II->getArgOperand(1); + Value *X, *Y; + if (match(I0, m_ZExt(m_Value(X))) && match(I1, m_ZExt(m_Value(Y))) && + (I0->hasOneUse() || I1->hasOneUse()) && X->getType() == Y->getType()) { + Value *NarrowMaxMin = Builder.CreateBinaryIntrinsic(IID, X, Y); + return CastInst::Create(Instruction::ZExt, NarrowMaxMin, II->getType()); + } + Constant *C; + if 
(match(I0, m_ZExt(m_Value(X))) && match(I1, m_Constant(C)) && + I0->hasOneUse()) { + Constant *NarrowC = ConstantExpr::getTrunc(C, X->getType()); + if (ConstantExpr::getZExt(NarrowC, II->getType()) == C) { + Value *NarrowMaxMin = Builder.CreateBinaryIntrinsic(IID, X, NarrowC); + return CastInst::Create(Instruction::ZExt, NarrowMaxMin, II->getType()); + } + } + // If both operands of unsigned min/max are sign-extended, it is still ok + // to narrow the operation. + LLVM_FALLTHROUGH; + } + case Intrinsic::smax: + case Intrinsic::smin: { + Value *I0 = II->getArgOperand(0), *I1 = II->getArgOperand(1); + Value *X, *Y; + if (match(I0, m_SExt(m_Value(X))) && match(I1, m_SExt(m_Value(Y))) && + (I0->hasOneUse() || I1->hasOneUse()) && X->getType() == Y->getType()) { + Value *NarrowMaxMin = Builder.CreateBinaryIntrinsic(IID, X, Y); + return CastInst::Create(Instruction::SExt, NarrowMaxMin, II->getType()); + } + + Constant *C; + if (match(I0, m_SExt(m_Value(X))) && match(I1, m_Constant(C)) && + I0->hasOneUse()) { + Constant *NarrowC = ConstantExpr::getTrunc(C, X->getType()); + if (ConstantExpr::getSExt(NarrowC, II->getType()) == C) { + Value *NarrowMaxMin = Builder.CreateBinaryIntrinsic(IID, X, NarrowC); + return CastInst::Create(Instruction::SExt, NarrowMaxMin, II->getType()); + } + } + + if (IID == Intrinsic::smax || IID == Intrinsic::smin) { + // smax (neg nsw X), (neg nsw Y) --> neg nsw (smin X, Y) + // smin (neg nsw X), (neg nsw Y) --> neg nsw (smax X, Y) + // TODO: Canonicalize neg after min/max if I1 is constant. 
+ if (match(I0, m_NSWNeg(m_Value(X))) && match(I1, m_NSWNeg(m_Value(Y))) && + (I0->hasOneUse() || I1->hasOneUse())) { + Intrinsic::ID InvID = getInverseMinMaxIntrinsic(IID); + Value *InvMaxMin = Builder.CreateBinaryIntrinsic(InvID, X, Y); + return BinaryOperator::CreateNSWNeg(InvMaxMin); + } + } + + // If we can eliminate ~A and Y is free to invert: + // max ~A, Y --> ~(min A, ~Y) + // + // Examples: + // max ~A, ~Y --> ~(min A, Y) + // max ~A, C --> ~(min A, ~C) + // max ~A, (max ~Y, ~Z) --> ~min( A, (min Y, Z)) + auto moveNotAfterMinMax = [&](Value *X, Value *Y) -> Instruction * { + Value *A; + if (match(X, m_OneUse(m_Not(m_Value(A)))) && + !isFreeToInvert(A, A->hasOneUse()) && + isFreeToInvert(Y, Y->hasOneUse())) { + Value *NotY = Builder.CreateNot(Y); + Intrinsic::ID InvID = getInverseMinMaxIntrinsic(IID); + Value *InvMaxMin = Builder.CreateBinaryIntrinsic(InvID, A, NotY); + return BinaryOperator::CreateNot(InvMaxMin); + } + return nullptr; + }; + + if (Instruction *I = moveNotAfterMinMax(I0, I1)) + return I; + if (Instruction *I = moveNotAfterMinMax(I1, I0)) + return I; + + if (Instruction *I = moveAddAfterMinMax(II, Builder)) + return I; + + // smax(X, -X) --> abs(X) + // smin(X, -X) --> -abs(X) + // umax(X, -X) --> -abs(X) + // umin(X, -X) --> abs(X) + if (isKnownNegation(I0, I1)) { + // We can choose either operand as the input to abs(), but if we can + // eliminate the only use of a value, that's better for subsequent + // transforms/analysis. + if (I0->hasOneUse() && !I1->hasOneUse()) + std::swap(I0, I1); + + // This is some variant of abs(). See if we can propagate 'nsw' to the abs + // operation and potentially its negation. + bool IntMinIsPoison = isKnownNegation(I0, I1, /* NeedNSW */ true); + Value *Abs = Builder.CreateBinaryIntrinsic( + Intrinsic::abs, I0, + ConstantInt::getBool(II->getContext(), IntMinIsPoison)); + + // We don't have a "nabs" intrinsic, so negate if needed based on the + // max/min operation. 
+ if (IID == Intrinsic::smin || IID == Intrinsic::umax) + Abs = Builder.CreateNeg(Abs, "nabs", /* NUW */ false, IntMinIsPoison); + return replaceInstUsesWith(CI, Abs); + } + + if (Instruction *Sel = foldClampRangeOfTwo(II, Builder)) + return Sel; + + if (Instruction *SAdd = matchSAddSubSat(*II)) + return SAdd; + + if (match(I1, m_ImmConstant())) + if (auto *Sel = dyn_cast<SelectInst>(I0)) + if (Instruction *R = FoldOpIntoSelect(*II, Sel)) + return R; + + if (Instruction *NewMinMax = reassociateMinMaxWithConstants(II)) + return NewMinMax; + + if (Instruction *R = reassociateMinMaxWithConstantInOperand(II, Builder)) + return R; + + if (Instruction *NewMinMax = factorizeMinMaxTree(II)) + return NewMinMax; + + break; + } + case Intrinsic::bswap: { + Value *IIOperand = II->getArgOperand(0); + + // Try to canonicalize bswap-of-logical-shift-by-8-bit-multiple as + // inverse-shift-of-bswap: + // bswap (shl X, Y) --> lshr (bswap X), Y + // bswap (lshr X, Y) --> shl (bswap X), Y + Value *X, *Y; + if (match(IIOperand, m_OneUse(m_LogicalShift(m_Value(X), m_Value(Y))))) { + // The transform allows undef vector elements, so try a constant match + // first. If knownbits can handle that case, that clause could be removed. + unsigned BitWidth = IIOperand->getType()->getScalarSizeInBits(); + const APInt *C; + if ((match(Y, m_APIntAllowUndef(C)) && (*C & 7) == 0) || + MaskedValueIsZero(Y, APInt::getLowBitsSet(BitWidth, 3))) { + Value *NewSwap = Builder.CreateUnaryIntrinsic(Intrinsic::bswap, X); + BinaryOperator::BinaryOps InverseShift = + cast<BinaryOperator>(IIOperand)->getOpcode() == Instruction::Shl + ? 
Instruction::LShr + : Instruction::Shl; + return BinaryOperator::Create(InverseShift, NewSwap, Y); + } + } + + KnownBits Known = computeKnownBits(IIOperand, 0, II); + uint64_t LZ = alignDown(Known.countMinLeadingZeros(), 8); + uint64_t TZ = alignDown(Known.countMinTrailingZeros(), 8); + unsigned BW = Known.getBitWidth(); + + // bswap(x) -> shift(x) if x has exactly one "active byte" + if (BW - LZ - TZ == 8) { + assert(LZ != TZ && "active byte cannot be in the middle"); + if (LZ > TZ) // -> shl(x) if the "active byte" is in the low part of x + return BinaryOperator::CreateNUWShl( + IIOperand, ConstantInt::get(IIOperand->getType(), LZ - TZ)); + // -> lshr(x) if the "active byte" is in the high part of x + return BinaryOperator::CreateExactLShr( + IIOperand, ConstantInt::get(IIOperand->getType(), TZ - LZ)); + } + + // bswap(trunc(bswap(x))) -> trunc(lshr(x, c)) + if (match(IIOperand, m_Trunc(m_BSwap(m_Value(X))))) { + unsigned C = X->getType()->getScalarSizeInBits() - BW; + Value *CV = ConstantInt::get(X->getType(), C); + Value *V = Builder.CreateLShr(X, CV); + return new TruncInst(V, IIOperand->getType()); + } + break; + } + case Intrinsic::masked_load: + if (Value *SimplifiedMaskedOp = simplifyMaskedLoad(*II)) + return replaceInstUsesWith(CI, SimplifiedMaskedOp); + break; + case Intrinsic::masked_store: + return simplifyMaskedStore(*II); + case Intrinsic::masked_gather: + return simplifyMaskedGather(*II); + case Intrinsic::masked_scatter: + return simplifyMaskedScatter(*II); + case Intrinsic::launder_invariant_group: + case Intrinsic::strip_invariant_group: + if (auto *SkippedBarrier = simplifyInvariantGroupIntrinsic(*II, *this)) + return replaceInstUsesWith(*II, SkippedBarrier); + break; + case Intrinsic::powi: + if (ConstantInt *Power = dyn_cast<ConstantInt>(II->getArgOperand(1))) { + // 0 and 1 are handled in instsimplify + // powi(x, -1) -> 1/x + if (Power->isMinusOne()) + return BinaryOperator::CreateFDivFMF(ConstantFP::get(CI.getType(), 1.0), + 
II->getArgOperand(0), II); + // powi(x, 2) -> x*x + if (Power->equalsInt(2)) + return BinaryOperator::CreateFMulFMF(II->getArgOperand(0), + II->getArgOperand(0), II); + + if (!Power->getValue()[0]) { + Value *X; + // If power is even: + // powi(-x, p) -> powi(x, p) + // powi(fabs(x), p) -> powi(x, p) + // powi(copysign(x, y), p) -> powi(x, p) + if (match(II->getArgOperand(0), m_FNeg(m_Value(X))) || + match(II->getArgOperand(0), m_FAbs(m_Value(X))) || + match(II->getArgOperand(0), + m_Intrinsic<Intrinsic::copysign>(m_Value(X), m_Value()))) + return replaceOperand(*II, 0, X); + } + } + break; + + case Intrinsic::cttz: + case Intrinsic::ctlz: + if (auto *I = foldCttzCtlz(*II, *this)) + return I; + break; + + case Intrinsic::ctpop: + if (auto *I = foldCtpop(*II, *this)) + return I; + break; + + case Intrinsic::fshl: + case Intrinsic::fshr: { + Value *Op0 = II->getArgOperand(0), *Op1 = II->getArgOperand(1); + Type *Ty = II->getType(); + unsigned BitWidth = Ty->getScalarSizeInBits(); + Constant *ShAmtC; + if (match(II->getArgOperand(2), m_ImmConstant(ShAmtC)) && + !ShAmtC->containsConstantExpression()) { + // Canonicalize a shift amount constant operand to modulo the bit-width. + Constant *WidthC = ConstantInt::get(Ty, BitWidth); + Constant *ModuloC = ConstantExpr::getURem(ShAmtC, WidthC); + if (ModuloC != ShAmtC) + return replaceOperand(*II, 2, ModuloC); + + assert(ConstantExpr::getICmp(ICmpInst::ICMP_UGT, WidthC, ShAmtC) == + ConstantInt::getTrue(CmpInst::makeCmpResultType(Ty)) && + "Shift amount expected to be modulo bitwidth"); + + // Canonicalize funnel shift right by constant to funnel shift left. This + // is not entirely arbitrary. For historical reasons, the backend may + // recognize rotate left patterns but miss rotate right patterns. 
+ if (IID == Intrinsic::fshr) { + // fshr X, Y, C --> fshl X, Y, (BitWidth - C) + Constant *LeftShiftC = ConstantExpr::getSub(WidthC, ShAmtC); + Module *Mod = II->getModule(); + Function *Fshl = Intrinsic::getDeclaration(Mod, Intrinsic::fshl, Ty); + return CallInst::Create(Fshl, { Op0, Op1, LeftShiftC }); + } + assert(IID == Intrinsic::fshl && + "All funnel shifts by simple constants should go left"); + + // fshl(X, 0, C) --> shl X, C + // fshl(X, undef, C) --> shl X, C + if (match(Op1, m_ZeroInt()) || match(Op1, m_Undef())) + return BinaryOperator::CreateShl(Op0, ShAmtC); + + // fshl(0, X, C) --> lshr X, (BW-C) + // fshl(undef, X, C) --> lshr X, (BW-C) + if (match(Op0, m_ZeroInt()) || match(Op0, m_Undef())) + return BinaryOperator::CreateLShr(Op1, + ConstantExpr::getSub(WidthC, ShAmtC)); + + // fshl i16 X, X, 8 --> bswap i16 X (reduce to more-specific form) + if (Op0 == Op1 && BitWidth == 16 && match(ShAmtC, m_SpecificInt(8))) { + Module *Mod = II->getModule(); + Function *Bswap = Intrinsic::getDeclaration(Mod, Intrinsic::bswap, Ty); + return CallInst::Create(Bswap, { Op0 }); + } + } + + // Left or right might be masked. + if (SimplifyDemandedInstructionBits(*II)) + return &CI; + + // The shift amount (operand 2) of a funnel shift is modulo the bitwidth, + // so only the low bits of the shift amount are demanded if the bitwidth is + // a power-of-2. 
+ if (!isPowerOf2_32(BitWidth)) + break; + APInt Op2Demanded = APInt::getLowBitsSet(BitWidth, Log2_32_Ceil(BitWidth)); + KnownBits Op2Known(BitWidth); + if (SimplifyDemandedBits(II, 2, Op2Demanded, Op2Known)) + return &CI; + break; + } + case Intrinsic::uadd_with_overflow: + case Intrinsic::sadd_with_overflow: { + if (Instruction *I = foldIntrinsicWithOverflowCommon(II)) + return I; + + // Given 2 constant operands whose sum does not overflow: + // uaddo (X +nuw C0), C1 -> uaddo X, C0 + C1 + // saddo (X +nsw C0), C1 -> saddo X, C0 + C1 + Value *X; + const APInt *C0, *C1; + Value *Arg0 = II->getArgOperand(0); + Value *Arg1 = II->getArgOperand(1); + bool IsSigned = IID == Intrinsic::sadd_with_overflow; + bool HasNWAdd = IsSigned ? match(Arg0, m_NSWAdd(m_Value(X), m_APInt(C0))) + : match(Arg0, m_NUWAdd(m_Value(X), m_APInt(C0))); + if (HasNWAdd && match(Arg1, m_APInt(C1))) { + bool Overflow; + APInt NewC = + IsSigned ? C1->sadd_ov(*C0, Overflow) : C1->uadd_ov(*C0, Overflow); + if (!Overflow) + return replaceInstUsesWith( + *II, Builder.CreateBinaryIntrinsic( + IID, X, ConstantInt::get(Arg1->getType(), NewC))); + } + break; + } + + case Intrinsic::umul_with_overflow: + case Intrinsic::smul_with_overflow: + case Intrinsic::usub_with_overflow: + if (Instruction *I = foldIntrinsicWithOverflowCommon(II)) + return I; + break; + + case Intrinsic::ssub_with_overflow: { + if (Instruction *I = foldIntrinsicWithOverflowCommon(II)) + return I; + + Constant *C; + Value *Arg0 = II->getArgOperand(0); + Value *Arg1 = II->getArgOperand(1); + // Given a constant C that is not the minimum signed value + // for an integer of a given bit width: + // + // ssubo X, C -> saddo X, -C + if (match(Arg1, m_Constant(C)) && C->isNotMinSignedValue()) { + Value *NegVal = ConstantExpr::getNeg(C); + // Build a saddo call that is equivalent to the discovered + // ssubo call. 
+ return replaceInstUsesWith( + *II, Builder.CreateBinaryIntrinsic(Intrinsic::sadd_with_overflow, + Arg0, NegVal)); + } + + break; + } + + case Intrinsic::uadd_sat: + case Intrinsic::sadd_sat: + case Intrinsic::usub_sat: + case Intrinsic::ssub_sat: { + SaturatingInst *SI = cast<SaturatingInst>(II); + Type *Ty = SI->getType(); + Value *Arg0 = SI->getLHS(); + Value *Arg1 = SI->getRHS(); + + // Make use of known overflow information. + OverflowResult OR = computeOverflow(SI->getBinaryOp(), SI->isSigned(), + Arg0, Arg1, SI); + switch (OR) { + case OverflowResult::MayOverflow: + break; + case OverflowResult::NeverOverflows: + if (SI->isSigned()) + return BinaryOperator::CreateNSW(SI->getBinaryOp(), Arg0, Arg1); + else + return BinaryOperator::CreateNUW(SI->getBinaryOp(), Arg0, Arg1); + case OverflowResult::AlwaysOverflowsLow: { + unsigned BitWidth = Ty->getScalarSizeInBits(); + APInt Min = APSInt::getMinValue(BitWidth, !SI->isSigned()); + return replaceInstUsesWith(*SI, ConstantInt::get(Ty, Min)); + } + case OverflowResult::AlwaysOverflowsHigh: { + unsigned BitWidth = Ty->getScalarSizeInBits(); + APInt Max = APSInt::getMaxValue(BitWidth, !SI->isSigned()); + return replaceInstUsesWith(*SI, ConstantInt::get(Ty, Max)); + } + } + + // ssub.sat(X, C) -> sadd.sat(X, -C) if C != MIN + Constant *C; + if (IID == Intrinsic::ssub_sat && match(Arg1, m_Constant(C)) && + C->isNotMinSignedValue()) { + Value *NegVal = ConstantExpr::getNeg(C); + return replaceInstUsesWith( + *II, Builder.CreateBinaryIntrinsic( + Intrinsic::sadd_sat, Arg0, NegVal)); + } + + // sat(sat(X + Val2) + Val) -> sat(X + (Val+Val2)) + // sat(sat(X - Val2) - Val) -> sat(X - (Val+Val2)) + // if Val and Val2 have the same sign + if (auto *Other = dyn_cast<IntrinsicInst>(Arg0)) { + Value *X; + const APInt *Val, *Val2; + APInt NewVal; + bool IsUnsigned = + IID == Intrinsic::uadd_sat || IID == Intrinsic::usub_sat; + if (Other->getIntrinsicID() == IID && + match(Arg1, m_APInt(Val)) && + match(Other->getArgOperand(0), 
m_Value(X)) && + match(Other->getArgOperand(1), m_APInt(Val2))) { + if (IsUnsigned) + NewVal = Val->uadd_sat(*Val2); + else if (Val->isNonNegative() == Val2->isNonNegative()) { + bool Overflow; + NewVal = Val->sadd_ov(*Val2, Overflow); + if (Overflow) { + // Both adds together may add more than SignedMaxValue + // without saturating the final result. + break; + } + } else { + // Cannot fold saturated addition with different signs. + break; + } + + return replaceInstUsesWith( + *II, Builder.CreateBinaryIntrinsic( + IID, X, ConstantInt::get(II->getType(), NewVal))); + } + } + break; + } + + case Intrinsic::minnum: + case Intrinsic::maxnum: + case Intrinsic::minimum: + case Intrinsic::maximum: { + Value *Arg0 = II->getArgOperand(0); + Value *Arg1 = II->getArgOperand(1); + Value *X, *Y; + if (match(Arg0, m_FNeg(m_Value(X))) && match(Arg1, m_FNeg(m_Value(Y))) && + (Arg0->hasOneUse() || Arg1->hasOneUse())) { + // If both operands are negated, invert the call and negate the result: + // min(-X, -Y) --> -(max(X, Y)) + // max(-X, -Y) --> -(min(X, Y)) + Intrinsic::ID NewIID; + switch (IID) { + case Intrinsic::maxnum: + NewIID = Intrinsic::minnum; + break; + case Intrinsic::minnum: + NewIID = Intrinsic::maxnum; + break; + case Intrinsic::maximum: + NewIID = Intrinsic::minimum; + break; + case Intrinsic::minimum: + NewIID = Intrinsic::maximum; + break; + default: + llvm_unreachable("unexpected intrinsic ID"); + } + Value *NewCall = Builder.CreateBinaryIntrinsic(NewIID, X, Y, II); + Instruction *FNeg = UnaryOperator::CreateFNeg(NewCall); + FNeg->copyIRFlags(II); + return FNeg; + } + + // m(m(X, C2), C1) -> m(X, C) + const APFloat *C1, *C2; + if (auto *M = dyn_cast<IntrinsicInst>(Arg0)) { + if (M->getIntrinsicID() == IID && match(Arg1, m_APFloat(C1)) && + ((match(M->getArgOperand(0), m_Value(X)) && + match(M->getArgOperand(1), m_APFloat(C2))) || + (match(M->getArgOperand(1), m_Value(X)) && + match(M->getArgOperand(0), m_APFloat(C2))))) { + APFloat Res(0.0); + switch (IID) { + 
case Intrinsic::maxnum: + Res = maxnum(*C1, *C2); + break; + case Intrinsic::minnum: + Res = minnum(*C1, *C2); + break; + case Intrinsic::maximum: + Res = maximum(*C1, *C2); + break; + case Intrinsic::minimum: + Res = minimum(*C1, *C2); + break; + default: + llvm_unreachable("unexpected intrinsic ID"); + } + Instruction *NewCall = Builder.CreateBinaryIntrinsic( + IID, X, ConstantFP::get(Arg0->getType(), Res), II); + // TODO: Conservatively intersecting FMF. If Res == C2, the transform + // was a simplification (so Arg0 and its original flags could + // propagate?) + NewCall->andIRFlags(M); + return replaceInstUsesWith(*II, NewCall); + } + } + + // m((fpext X), (fpext Y)) -> fpext (m(X, Y)) + if (match(Arg0, m_OneUse(m_FPExt(m_Value(X)))) && + match(Arg1, m_OneUse(m_FPExt(m_Value(Y)))) && + X->getType() == Y->getType()) { + Value *NewCall = + Builder.CreateBinaryIntrinsic(IID, X, Y, II, II->getName()); + return new FPExtInst(NewCall, II->getType()); + } + + // max X, -X --> fabs X + // min X, -X --> -(fabs X) + // TODO: Remove one-use limitation? That is obviously better for max. + // It would be an extra instruction for min (fnabs), but that is + // still likely better for analysis and codegen. + if ((match(Arg0, m_OneUse(m_FNeg(m_Value(X)))) && Arg1 == X) || + (match(Arg1, m_OneUse(m_FNeg(m_Value(X)))) && Arg0 == X)) { + Value *R = Builder.CreateUnaryIntrinsic(Intrinsic::fabs, X, II); + if (IID == Intrinsic::minimum || IID == Intrinsic::minnum) + R = Builder.CreateFNegFMF(R, II); + return replaceInstUsesWith(*II, R); + } + + break; + } + case Intrinsic::fmuladd: { + // Canonicalize fast fmuladd to the separate fmul + fadd. 
+ if (II->isFast()) { + BuilderTy::FastMathFlagGuard Guard(Builder); + Builder.setFastMathFlags(II->getFastMathFlags()); + Value *Mul = Builder.CreateFMul(II->getArgOperand(0), + II->getArgOperand(1)); + Value *Add = Builder.CreateFAdd(Mul, II->getArgOperand(2)); + Add->takeName(II); + return replaceInstUsesWith(*II, Add); + } + + // Try to simplify the underlying FMul. + if (Value *V = simplifyFMulInst(II->getArgOperand(0), II->getArgOperand(1), + II->getFastMathFlags(), + SQ.getWithInstruction(II))) { + auto *FAdd = BinaryOperator::CreateFAdd(V, II->getArgOperand(2)); + FAdd->copyFastMathFlags(II); + return FAdd; + } + + LLVM_FALLTHROUGH; + } + case Intrinsic::fma: { + // fma fneg(x), fneg(y), z -> fma x, y, z + Value *Src0 = II->getArgOperand(0); + Value *Src1 = II->getArgOperand(1); + Value *X, *Y; + if (match(Src0, m_FNeg(m_Value(X))) && match(Src1, m_FNeg(m_Value(Y)))) { + replaceOperand(*II, 0, X); + replaceOperand(*II, 1, Y); + return II; + } + + // fma fabs(x), fabs(x), z -> fma x, x, z + if (match(Src0, m_FAbs(m_Value(X))) && + match(Src1, m_FAbs(m_Specific(X)))) { + replaceOperand(*II, 0, X); + replaceOperand(*II, 1, X); + return II; + } + + // Try to simplify the underlying FMul. We can only apply simplifications + // that do not require rounding. + if (Value *V = simplifyFMAFMul(II->getArgOperand(0), II->getArgOperand(1), + II->getFastMathFlags(), + SQ.getWithInstruction(II))) { + auto *FAdd = BinaryOperator::CreateFAdd(V, II->getArgOperand(2)); + FAdd->copyFastMathFlags(II); + return FAdd; + } + + // fma x, y, 0 -> fmul x, y + // This is always valid for -0.0, but requires nsz for +0.0 as + // -0.0 + 0.0 = 0.0, which would not be the same as the fmul on its own. 
+ if (match(II->getArgOperand(2), m_NegZeroFP()) || + (match(II->getArgOperand(2), m_PosZeroFP()) && + II->getFastMathFlags().noSignedZeros())) + return BinaryOperator::CreateFMulFMF(Src0, Src1, II); + + break; + } + case Intrinsic::copysign: { + Value *Mag = II->getArgOperand(0), *Sign = II->getArgOperand(1); + if (SignBitMustBeZero(Sign, &TLI)) { + // If we know that the sign argument is positive, reduce to FABS: + // copysign Mag, +Sign --> fabs Mag + Value *Fabs = Builder.CreateUnaryIntrinsic(Intrinsic::fabs, Mag, II); + return replaceInstUsesWith(*II, Fabs); + } + // TODO: There should be a ValueTracking sibling like SignBitMustBeOne. + const APFloat *C; + if (match(Sign, m_APFloat(C)) && C->isNegative()) { + // If we know that the sign argument is negative, reduce to FNABS: + // copysign Mag, -Sign --> fneg (fabs Mag) + Value *Fabs = Builder.CreateUnaryIntrinsic(Intrinsic::fabs, Mag, II); + return replaceInstUsesWith(*II, Builder.CreateFNegFMF(Fabs, II)); + } + + // Propagate sign argument through nested calls: + // copysign Mag, (copysign ?, X) --> copysign Mag, X + Value *X; + if (match(Sign, m_Intrinsic<Intrinsic::copysign>(m_Value(), m_Value(X)))) + return replaceOperand(*II, 1, X); + + // Peek through changes of magnitude's sign-bit. 
This call rewrites those: + // copysign (fabs X), Sign --> copysign X, Sign + // copysign (fneg X), Sign --> copysign X, Sign + if (match(Mag, m_FAbs(m_Value(X))) || match(Mag, m_FNeg(m_Value(X)))) + return replaceOperand(*II, 0, X); + + break; + } + case Intrinsic::fabs: { + Value *Cond, *TVal, *FVal; + if (match(II->getArgOperand(0), + m_Select(m_Value(Cond), m_Value(TVal), m_Value(FVal)))) { + // fabs (select Cond, TrueC, FalseC) --> select Cond, AbsT, AbsF + if (isa<Constant>(TVal) && isa<Constant>(FVal)) { + CallInst *AbsT = Builder.CreateCall(II->getCalledFunction(), {TVal}); + CallInst *AbsF = Builder.CreateCall(II->getCalledFunction(), {FVal}); + return SelectInst::Create(Cond, AbsT, AbsF); + } + // fabs (select Cond, -FVal, FVal) --> fabs FVal + if (match(TVal, m_FNeg(m_Specific(FVal)))) + return replaceOperand(*II, 0, FVal); + // fabs (select Cond, TVal, -TVal) --> fabs TVal + if (match(FVal, m_FNeg(m_Specific(TVal)))) + return replaceOperand(*II, 0, TVal); + } + + LLVM_FALLTHROUGH; + } + case Intrinsic::ceil: + case Intrinsic::floor: + case Intrinsic::round: + case Intrinsic::roundeven: + case Intrinsic::nearbyint: + case Intrinsic::rint: + case Intrinsic::trunc: { + Value *ExtSrc; + if (match(II->getArgOperand(0), m_OneUse(m_FPExt(m_Value(ExtSrc))))) { + // Narrow the call: intrinsic (fpext x) -> fpext (intrinsic x) + Value *NarrowII = Builder.CreateUnaryIntrinsic(IID, ExtSrc, II); + return new FPExtInst(NarrowII, II->getType()); + } + break; + } + case Intrinsic::cos: + case Intrinsic::amdgcn_cos: { + Value *X; + Value *Src = II->getArgOperand(0); + if (match(Src, m_FNeg(m_Value(X))) || match(Src, m_FAbs(m_Value(X)))) { + // cos(-x) -> cos(x) + // cos(fabs(x)) -> cos(x) + return replaceOperand(*II, 0, X); + } + break; + } + case Intrinsic::sin: { + Value *X; + if (match(II->getArgOperand(0), m_OneUse(m_FNeg(m_Value(X))))) { + // sin(-x) --> -sin(x) + Value *NewSin = Builder.CreateUnaryIntrinsic(Intrinsic::sin, X, II); + Instruction *FNeg = 
UnaryOperator::CreateFNeg(NewSin); + FNeg->copyFastMathFlags(II); + return FNeg; + } + break; + } + + case Intrinsic::arm_neon_vtbl1: + case Intrinsic::aarch64_neon_tbl1: + if (Value *V = simplifyNeonTbl1(*II, Builder)) + return replaceInstUsesWith(*II, V); + break; + + case Intrinsic::arm_neon_vmulls: + case Intrinsic::arm_neon_vmullu: + case Intrinsic::aarch64_neon_smull: + case Intrinsic::aarch64_neon_umull: { + Value *Arg0 = II->getArgOperand(0); + Value *Arg1 = II->getArgOperand(1); + + // Handle mul by zero first: + if (isa<ConstantAggregateZero>(Arg0) || isa<ConstantAggregateZero>(Arg1)) { + return replaceInstUsesWith(CI, ConstantAggregateZero::get(II->getType())); + } + + // Check for constant LHS & RHS - in this case we just simplify. + bool Zext = (IID == Intrinsic::arm_neon_vmullu || + IID == Intrinsic::aarch64_neon_umull); + VectorType *NewVT = cast<VectorType>(II->getType()); + if (Constant *CV0 = dyn_cast<Constant>(Arg0)) { + if (Constant *CV1 = dyn_cast<Constant>(Arg1)) { + CV0 = ConstantExpr::getIntegerCast(CV0, NewVT, /*isSigned=*/!Zext); + CV1 = ConstantExpr::getIntegerCast(CV1, NewVT, /*isSigned=*/!Zext); + + return replaceInstUsesWith(CI, ConstantExpr::getMul(CV0, CV1)); + } + + // Couldn't simplify - canonicalize constant to the RHS. 
+ std::swap(Arg0, Arg1); + } + + // Handle mul by one: + if (Constant *CV1 = dyn_cast<Constant>(Arg1)) + if (ConstantInt *Splat = + dyn_cast_or_null<ConstantInt>(CV1->getSplatValue())) + if (Splat->isOne()) + return CastInst::CreateIntegerCast(Arg0, II->getType(), + /*isSigned=*/!Zext); + + break; + } + case Intrinsic::arm_neon_aesd: + case Intrinsic::arm_neon_aese: + case Intrinsic::aarch64_crypto_aesd: + case Intrinsic::aarch64_crypto_aese: { + Value *DataArg = II->getArgOperand(0); + Value *KeyArg = II->getArgOperand(1); + + // Try to use the builtin XOR in AESE and AESD to eliminate a prior XOR + Value *Data, *Key; + if (match(KeyArg, m_ZeroInt()) && + match(DataArg, m_Xor(m_Value(Data), m_Value(Key)))) { + replaceOperand(*II, 0, Data); + replaceOperand(*II, 1, Key); + return II; + } + break; + } + case Intrinsic::hexagon_V6_vandvrt: + case Intrinsic::hexagon_V6_vandvrt_128B: { + // Simplify Q -> V -> Q conversion. + if (auto Op0 = dyn_cast<IntrinsicInst>(II->getArgOperand(0))) { + Intrinsic::ID ID0 = Op0->getIntrinsicID(); + if (ID0 != Intrinsic::hexagon_V6_vandqrt && + ID0 != Intrinsic::hexagon_V6_vandqrt_128B) + break; + Value *Bytes = Op0->getArgOperand(1), *Mask = II->getArgOperand(1); + uint64_t Bytes1 = computeKnownBits(Bytes, 0, Op0).One.getZExtValue(); + uint64_t Mask1 = computeKnownBits(Mask, 0, II).One.getZExtValue(); + // Check if every byte has common bits in Bytes and Mask. 
+ uint64_t C = Bytes1 & Mask1; + if ((C & 0xFF) && (C & 0xFF00) && (C & 0xFF0000) && (C & 0xFF000000)) + return replaceInstUsesWith(*II, Op0->getArgOperand(0)); + } + break; + } + case Intrinsic::stackrestore: { + enum class ClassifyResult { + None, + Alloca, + StackRestore, + CallWithSideEffects, + }; + auto Classify = [](const Instruction *I) { + if (isa<AllocaInst>(I)) + return ClassifyResult::Alloca; + + if (auto *CI = dyn_cast<CallInst>(I)) { + if (auto *II = dyn_cast<IntrinsicInst>(CI)) { + if (II->getIntrinsicID() == Intrinsic::stackrestore) + return ClassifyResult::StackRestore; + + if (II->mayHaveSideEffects()) + return ClassifyResult::CallWithSideEffects; + } else { + // Consider all non-intrinsic calls to be side effects + return ClassifyResult::CallWithSideEffects; + } + } + + return ClassifyResult::None; + }; + + // If the stacksave and the stackrestore are in the same BB, and there is + // no intervening call, alloca, or stackrestore of a different stacksave, + // remove the restore. This can happen when variable allocas are DCE'd. + if (IntrinsicInst *SS = dyn_cast<IntrinsicInst>(II->getArgOperand(0))) { + if (SS->getIntrinsicID() == Intrinsic::stacksave && + SS->getParent() == II->getParent()) { + BasicBlock::iterator BI(SS); + bool CannotRemove = false; + for (++BI; &*BI != II; ++BI) { + switch (Classify(&*BI)) { + case ClassifyResult::None: + // So far so good, look at next instructions. + break; + + case ClassifyResult::StackRestore: + // If we found an intervening stackrestore for a different + // stacksave, we can't remove the stackrestore. Otherwise, continue. + if (cast<IntrinsicInst>(*BI).getArgOperand(0) != SS) + CannotRemove = true; + break; + + case ClassifyResult::Alloca: + case ClassifyResult::CallWithSideEffects: + // If we found an alloca, a non-intrinsic call, or an intrinsic + // call with side effects, we can't remove the stackrestore. 
+ CannotRemove = true; + break; + } + if (CannotRemove) + break; + } + + if (!CannotRemove) + return eraseInstFromFunction(CI); + } + } + + // Scan down this block to see if there is another stack restore in the + // same block without an intervening call/alloca. + BasicBlock::iterator BI(II); + Instruction *TI = II->getParent()->getTerminator(); + bool CannotRemove = false; + for (++BI; &*BI != TI; ++BI) { + switch (Classify(&*BI)) { + case ClassifyResult::None: + // So far so good, look at next instructions. + break; + + case ClassifyResult::StackRestore: + // If there is a stackrestore below this one, remove this one. + return eraseInstFromFunction(CI); + + case ClassifyResult::Alloca: + case ClassifyResult::CallWithSideEffects: + // If we found an alloca, a non-intrinsic call, or an intrinsic call + // with side effects (such as llvm.stacksave and llvm.read_register), + // we can't remove the stack restore. + CannotRemove = true; + break; + } + if (CannotRemove) + break; + } + + // If the stack restore is in a return, resume, or unwind block and if there + // are no allocas or calls between the restore and the return, nuke the + // restore. + if (!CannotRemove && (isa<ReturnInst>(TI) || isa<ResumeInst>(TI))) + return eraseInstFromFunction(CI); + break; + } + case Intrinsic::lifetime_end: + // Asan needs to poison memory to detect invalid access which is possible + // even for empty lifetime range. 
+ if (II->getFunction()->hasFnAttribute(Attribute::SanitizeAddress) || + II->getFunction()->hasFnAttribute(Attribute::SanitizeMemory) || + II->getFunction()->hasFnAttribute(Attribute::SanitizeHWAddress)) + break; + + if (removeTriviallyEmptyRange(*II, *this, [](const IntrinsicInst &I) { + return I.getIntrinsicID() == Intrinsic::lifetime_start; + })) + return nullptr; + break; + case Intrinsic::assume: { + Value *IIOperand = II->getArgOperand(0); + SmallVector<OperandBundleDef, 4> OpBundles; + II->getOperandBundlesAsDefs(OpBundles); + + /// This will remove the boolean Condition from the assume given as + /// argument and remove the assume if it becomes useless. + /// always returns nullptr for use as a return values. + auto RemoveConditionFromAssume = [&](Instruction *Assume) -> Instruction * { + assert(isa<AssumeInst>(Assume)); + if (isAssumeWithEmptyBundle(*cast<AssumeInst>(II))) + return eraseInstFromFunction(CI); + replaceUse(II->getOperandUse(0), ConstantInt::getTrue(II->getContext())); + return nullptr; + }; + // Remove an assume if it is followed by an identical assume. + // TODO: Do we need this? Unless there are conflicting assumptions, the + // computeKnownBits(IIOperand) below here eliminates redundant assumes. + Instruction *Next = II->getNextNonDebugInstruction(); + if (match(Next, m_Intrinsic<Intrinsic::assume>(m_Specific(IIOperand)))) + return RemoveConditionFromAssume(Next); + + // Canonicalize assume(a && b) -> assume(a); assume(b); + // Note: New assumption intrinsics created here are registered by + // the InstCombineIRInserter object. 
+ FunctionType *AssumeIntrinsicTy = II->getFunctionType(); + Value *AssumeIntrinsic = II->getCalledOperand(); + Value *A, *B; + if (match(IIOperand, m_LogicalAnd(m_Value(A), m_Value(B)))) { + Builder.CreateCall(AssumeIntrinsicTy, AssumeIntrinsic, A, OpBundles, + II->getName()); + Builder.CreateCall(AssumeIntrinsicTy, AssumeIntrinsic, B, II->getName()); + return eraseInstFromFunction(*II); + } + // assume(!(a || b)) -> assume(!a); assume(!b); + if (match(IIOperand, m_Not(m_LogicalOr(m_Value(A), m_Value(B))))) { + Builder.CreateCall(AssumeIntrinsicTy, AssumeIntrinsic, + Builder.CreateNot(A), OpBundles, II->getName()); + Builder.CreateCall(AssumeIntrinsicTy, AssumeIntrinsic, + Builder.CreateNot(B), II->getName()); + return eraseInstFromFunction(*II); + } + + // assume( (load addr) != null ) -> add 'nonnull' metadata to load + // (if assume is valid at the load) + CmpInst::Predicate Pred; + Instruction *LHS; + if (match(IIOperand, m_ICmp(Pred, m_Instruction(LHS), m_Zero())) && + Pred == ICmpInst::ICMP_NE && LHS->getOpcode() == Instruction::Load && + LHS->getType()->isPointerTy() && + isValidAssumeForContext(II, LHS, &DT)) { + MDNode *MD = MDNode::get(II->getContext(), None); + LHS->setMetadata(LLVMContext::MD_nonnull, MD); + return RemoveConditionFromAssume(II); + + // TODO: apply nonnull return attributes to calls and invokes + // TODO: apply range metadata for range check patterns? 
+ } + + // Convert nonnull assume like: + // %A = icmp ne i32* %PTR, null + // call void @llvm.assume(i1 %A) + // into + // call void @llvm.assume(i1 true) [ "nonnull"(i32* %PTR) ] + if (EnableKnowledgeRetention && + match(IIOperand, m_Cmp(Pred, m_Value(A), m_Zero())) && + Pred == CmpInst::ICMP_NE && A->getType()->isPointerTy()) { + if (auto *Replacement = buildAssumeFromKnowledge( + {RetainedKnowledge{Attribute::NonNull, 0, A}}, Next, &AC, &DT)) { + + Replacement->insertBefore(Next); + AC.registerAssumption(Replacement); + return RemoveConditionFromAssume(II); + } + } + + // Convert alignment assume like: + // %B = ptrtoint i32* %A to i64 + // %C = and i64 %B, Constant + // %D = icmp eq i64 %C, 0 + // call void @llvm.assume(i1 %D) + // into + // call void @llvm.assume(i1 true) [ "align"(i32* [[A]], i64 Constant + 1)] + uint64_t AlignMask; + if (EnableKnowledgeRetention && + match(IIOperand, + m_Cmp(Pred, m_And(m_Value(A), m_ConstantInt(AlignMask)), + m_Zero())) && + Pred == CmpInst::ICMP_EQ) { + if (isPowerOf2_64(AlignMask + 1)) { + uint64_t Offset = 0; + match(A, m_Add(m_Value(A), m_ConstantInt(Offset))); + if (match(A, m_PtrToInt(m_Value(A)))) { + /// Note: this doesn't preserve the offset information but merges + /// offset and alignment. + /// TODO: we can generate a GEP instead of merging the alignment with + /// the offset. + RetainedKnowledge RK{Attribute::Alignment, + (unsigned)MinAlign(Offset, AlignMask + 1), A}; + if (auto *Replacement = + buildAssumeFromKnowledge(RK, Next, &AC, &DT)) { + + Replacement->insertAfter(II); + AC.registerAssumption(Replacement); + } + return RemoveConditionFromAssume(II); + } + } + } + + /// Canonicalize Knowledge in operand bundles. 
+ if (EnableKnowledgeRetention && II->hasOperandBundles()) { + for (unsigned Idx = 0; Idx < II->getNumOperandBundles(); Idx++) { + auto &BOI = II->bundle_op_info_begin()[Idx]; + RetainedKnowledge RK = + llvm::getKnowledgeFromBundle(cast<AssumeInst>(*II), BOI); + if (BOI.End - BOI.Begin > 2) + continue; // Prevent reducing knowledge in an align with offset since + // extracting a RetainedKnowledge form them looses offset + // information + RetainedKnowledge CanonRK = + llvm::simplifyRetainedKnowledge(cast<AssumeInst>(II), RK, + &getAssumptionCache(), + &getDominatorTree()); + if (CanonRK == RK) + continue; + if (!CanonRK) { + if (BOI.End - BOI.Begin > 0) { + Worklist.pushValue(II->op_begin()[BOI.Begin]); + Value::dropDroppableUse(II->op_begin()[BOI.Begin]); + } + continue; + } + assert(RK.AttrKind == CanonRK.AttrKind); + if (BOI.End - BOI.Begin > 0) + II->op_begin()[BOI.Begin].set(CanonRK.WasOn); + if (BOI.End - BOI.Begin > 1) + II->op_begin()[BOI.Begin + 1].set(ConstantInt::get( + Type::getInt64Ty(II->getContext()), CanonRK.ArgValue)); + if (RK.WasOn) + Worklist.pushValue(RK.WasOn); + return II; + } + } + + // If there is a dominating assume with the same condition as this one, + // then this one is redundant, and should be removed. + KnownBits Known(1); + computeKnownBits(IIOperand, Known, 0, II); + if (Known.isAllOnes() && isAssumeWithEmptyBundle(cast<AssumeInst>(*II))) + return eraseInstFromFunction(*II); + + // Update the cache of affected values for this assumption (we might be + // here because we just simplified the condition). + AC.updateAffectedValues(cast<AssumeInst>(II)); + break; + } + case Intrinsic::experimental_guard: { + // Is this guard followed by another guard? We scan forward over a small + // fixed window of instructions to handle common cases with conditions + // computed between guards. 
+ Instruction *NextInst = II->getNextNonDebugInstruction(); + for (unsigned i = 0; i < GuardWideningWindow; i++) { + // Note: Using context-free form to avoid compile time blow up + if (!isSafeToSpeculativelyExecute(NextInst)) + break; + NextInst = NextInst->getNextNonDebugInstruction(); + } + Value *NextCond = nullptr; + if (match(NextInst, + m_Intrinsic<Intrinsic::experimental_guard>(m_Value(NextCond)))) { + Value *CurrCond = II->getArgOperand(0); + + // Remove a guard that it is immediately preceded by an identical guard. + // Otherwise canonicalize guard(a); guard(b) -> guard(a & b). + if (CurrCond != NextCond) { + Instruction *MoveI = II->getNextNonDebugInstruction(); + while (MoveI != NextInst) { + auto *Temp = MoveI; + MoveI = MoveI->getNextNonDebugInstruction(); + Temp->moveBefore(II); + } + replaceOperand(*II, 0, Builder.CreateAnd(CurrCond, NextCond)); + } + eraseInstFromFunction(*NextInst); + return II; + } + break; + } + case Intrinsic::vector_insert: { + Value *Vec = II->getArgOperand(0); + Value *SubVec = II->getArgOperand(1); + Value *Idx = II->getArgOperand(2); + auto *DstTy = dyn_cast<FixedVectorType>(II->getType()); + auto *VecTy = dyn_cast<FixedVectorType>(Vec->getType()); + auto *SubVecTy = dyn_cast<FixedVectorType>(SubVec->getType()); + + // Only canonicalize if the destination vector, Vec, and SubVec are all + // fixed vectors. + if (DstTy && VecTy && SubVecTy) { + unsigned DstNumElts = DstTy->getNumElements(); + unsigned VecNumElts = VecTy->getNumElements(); + unsigned SubVecNumElts = SubVecTy->getNumElements(); + unsigned IdxN = cast<ConstantInt>(Idx)->getZExtValue(); + + // An insert that entirely overwrites Vec with SubVec is a nop. + if (VecNumElts == SubVecNumElts) + return replaceInstUsesWith(CI, SubVec); + + // Widen SubVec into a vector of the same width as Vec, since + // shufflevector requires the two input vectors to be the same width. + // Elements beyond the bounds of SubVec within the widened vector are + // undefined. 
+ SmallVector<int, 8> WidenMask; + unsigned i; + for (i = 0; i != SubVecNumElts; ++i) + WidenMask.push_back(i); + for (; i != VecNumElts; ++i) + WidenMask.push_back(UndefMaskElem); + + Value *WidenShuffle = Builder.CreateShuffleVector(SubVec, WidenMask); + + SmallVector<int, 8> Mask; + for (unsigned i = 0; i != IdxN; ++i) + Mask.push_back(i); + for (unsigned i = DstNumElts; i != DstNumElts + SubVecNumElts; ++i) + Mask.push_back(i); + for (unsigned i = IdxN + SubVecNumElts; i != DstNumElts; ++i) + Mask.push_back(i); + + Value *Shuffle = Builder.CreateShuffleVector(Vec, WidenShuffle, Mask); + return replaceInstUsesWith(CI, Shuffle); + } + break; + } + case Intrinsic::vector_extract: { + Value *Vec = II->getArgOperand(0); + Value *Idx = II->getArgOperand(1); + + auto *DstTy = dyn_cast<FixedVectorType>(II->getType()); + auto *VecTy = dyn_cast<FixedVectorType>(Vec->getType()); + + // Only canonicalize if the the destination vector and Vec are fixed + // vectors. + if (DstTy && VecTy) { + unsigned DstNumElts = DstTy->getNumElements(); + unsigned VecNumElts = VecTy->getNumElements(); + unsigned IdxN = cast<ConstantInt>(Idx)->getZExtValue(); + + // Extracting the entirety of Vec is a nop. 
+ if (VecNumElts == DstNumElts) { + replaceInstUsesWith(CI, Vec); + return eraseInstFromFunction(CI); + } + + SmallVector<int, 8> Mask; + for (unsigned i = 0; i != DstNumElts; ++i) + Mask.push_back(IdxN + i); + + Value *Shuffle = Builder.CreateShuffleVector(Vec, Mask); + return replaceInstUsesWith(CI, Shuffle); + } + break; + } + case Intrinsic::experimental_vector_reverse: { + Value *BO0, *BO1, *X, *Y; + Value *Vec = II->getArgOperand(0); + if (match(Vec, m_OneUse(m_BinOp(m_Value(BO0), m_Value(BO1))))) { + auto *OldBinOp = cast<BinaryOperator>(Vec); + if (match(BO0, m_Intrinsic<Intrinsic::experimental_vector_reverse>( + m_Value(X)))) { + // rev(binop rev(X), rev(Y)) --> binop X, Y + if (match(BO1, m_Intrinsic<Intrinsic::experimental_vector_reverse>( + m_Value(Y)))) + return replaceInstUsesWith(CI, + BinaryOperator::CreateWithCopiedFlags( + OldBinOp->getOpcode(), X, Y, OldBinOp, + OldBinOp->getName(), II)); + // rev(binop rev(X), BO1Splat) --> binop X, BO1Splat + if (isSplatValue(BO1)) + return replaceInstUsesWith(CI, + BinaryOperator::CreateWithCopiedFlags( + OldBinOp->getOpcode(), X, BO1, + OldBinOp, OldBinOp->getName(), II)); + } + // rev(binop BO0Splat, rev(Y)) --> binop BO0Splat, Y + if (match(BO1, m_Intrinsic<Intrinsic::experimental_vector_reverse>( + m_Value(Y))) && + isSplatValue(BO0)) + return replaceInstUsesWith(CI, BinaryOperator::CreateWithCopiedFlags( + OldBinOp->getOpcode(), BO0, Y, + OldBinOp, OldBinOp->getName(), II)); + } + // rev(unop rev(X)) --> unop X + if (match(Vec, m_OneUse(m_UnOp( + m_Intrinsic<Intrinsic::experimental_vector_reverse>( + m_Value(X)))))) { + auto *OldUnOp = cast<UnaryOperator>(Vec); + auto *NewUnOp = UnaryOperator::CreateWithCopiedFlags( + OldUnOp->getOpcode(), X, OldUnOp, OldUnOp->getName(), II); + return replaceInstUsesWith(CI, NewUnOp); + } + break; + } + case Intrinsic::vector_reduce_or: + case Intrinsic::vector_reduce_and: { + // Canonicalize logical or/and reductions: + // Or reduction for i1 is represented as: + // %val 
= bitcast <ReduxWidth x i1> to iReduxWidth + // %res = cmp ne iReduxWidth %val, 0 + // And reduction for i1 is represented as: + // %val = bitcast <ReduxWidth x i1> to iReduxWidth + // %res = cmp eq iReduxWidth %val, 11111 + Value *Arg = II->getArgOperand(0); + Value *Vect; + if (match(Arg, m_ZExtOrSExtOrSelf(m_Value(Vect)))) { + if (auto *FTy = dyn_cast<FixedVectorType>(Vect->getType())) + if (FTy->getElementType() == Builder.getInt1Ty()) { + Value *Res = Builder.CreateBitCast( + Vect, Builder.getIntNTy(FTy->getNumElements())); + if (IID == Intrinsic::vector_reduce_and) { + Res = Builder.CreateICmpEQ( + Res, ConstantInt::getAllOnesValue(Res->getType())); + } else { + assert(IID == Intrinsic::vector_reduce_or && + "Expected or reduction."); + Res = Builder.CreateIsNotNull(Res); + } + if (Arg != Vect) + Res = Builder.CreateCast(cast<CastInst>(Arg)->getOpcode(), Res, + II->getType()); + return replaceInstUsesWith(CI, Res); + } + } + LLVM_FALLTHROUGH; + } + case Intrinsic::vector_reduce_add: { + if (IID == Intrinsic::vector_reduce_add) { + // Convert vector_reduce_add(ZExt(<n x i1>)) to + // ZExtOrTrunc(ctpop(bitcast <n x i1> to in)). + // Convert vector_reduce_add(SExt(<n x i1>)) to + // -ZExtOrTrunc(ctpop(bitcast <n x i1> to in)). + // Convert vector_reduce_add(<n x i1>) to + // Trunc(ctpop(bitcast <n x i1> to in)). 
+ Value *Arg = II->getArgOperand(0); + Value *Vect; + if (match(Arg, m_ZExtOrSExtOrSelf(m_Value(Vect)))) { + if (auto *FTy = dyn_cast<FixedVectorType>(Vect->getType())) + if (FTy->getElementType() == Builder.getInt1Ty()) { + Value *V = Builder.CreateBitCast( + Vect, Builder.getIntNTy(FTy->getNumElements())); + Value *Res = Builder.CreateUnaryIntrinsic(Intrinsic::ctpop, V); + if (Res->getType() != II->getType()) + Res = Builder.CreateZExtOrTrunc(Res, II->getType()); + if (Arg != Vect && + cast<Instruction>(Arg)->getOpcode() == Instruction::SExt) + Res = Builder.CreateNeg(Res); + return replaceInstUsesWith(CI, Res); + } + } + } + LLVM_FALLTHROUGH; + } + case Intrinsic::vector_reduce_xor: { + if (IID == Intrinsic::vector_reduce_xor) { + // Exclusive disjunction reduction over the vector with + // (potentially-extended) i1 element type is actually a + // (potentially-extended) arithmetic `add` reduction over the original + // non-extended value: + // vector_reduce_xor(?ext(<n x i1>)) + // --> + // ?ext(vector_reduce_add(<n x i1>)) + Value *Arg = II->getArgOperand(0); + Value *Vect; + if (match(Arg, m_ZExtOrSExtOrSelf(m_Value(Vect)))) { + if (auto *FTy = dyn_cast<FixedVectorType>(Vect->getType())) + if (FTy->getElementType() == Builder.getInt1Ty()) { + Value *Res = Builder.CreateAddReduce(Vect); + if (Arg != Vect) + Res = Builder.CreateCast(cast<CastInst>(Arg)->getOpcode(), Res, + II->getType()); + return replaceInstUsesWith(CI, Res); + } + } + } + LLVM_FALLTHROUGH; + } + case Intrinsic::vector_reduce_mul: { + if (IID == Intrinsic::vector_reduce_mul) { + // Multiplicative reduction over the vector with (potentially-extended) + // i1 element type is actually a (potentially zero-extended) + // logical `and` reduction over the original non-extended value: + // vector_reduce_mul(?ext(<n x i1>)) + // --> + // zext(vector_reduce_and(<n x i1>)) + Value *Arg = II->getArgOperand(0); + Value *Vect; + if (match(Arg, m_ZExtOrSExtOrSelf(m_Value(Vect)))) { + if (auto *FTy = 
dyn_cast<FixedVectorType>(Vect->getType())) + if (FTy->getElementType() == Builder.getInt1Ty()) { + Value *Res = Builder.CreateAndReduce(Vect); + if (Res->getType() != II->getType()) + Res = Builder.CreateZExt(Res, II->getType()); + return replaceInstUsesWith(CI, Res); + } + } + } + LLVM_FALLTHROUGH; + } + case Intrinsic::vector_reduce_umin: + case Intrinsic::vector_reduce_umax: { + if (IID == Intrinsic::vector_reduce_umin || + IID == Intrinsic::vector_reduce_umax) { + // UMin/UMax reduction over the vector with (potentially-extended) + // i1 element type is actually a (potentially-extended) + // logical `and`/`or` reduction over the original non-extended value: + // vector_reduce_u{min,max}(?ext(<n x i1>)) + // --> + // ?ext(vector_reduce_{and,or}(<n x i1>)) + Value *Arg = II->getArgOperand(0); + Value *Vect; + if (match(Arg, m_ZExtOrSExtOrSelf(m_Value(Vect)))) { + if (auto *FTy = dyn_cast<FixedVectorType>(Vect->getType())) + if (FTy->getElementType() == Builder.getInt1Ty()) { + Value *Res = IID == Intrinsic::vector_reduce_umin + ? 
Builder.CreateAndReduce(Vect) + : Builder.CreateOrReduce(Vect); + if (Arg != Vect) + Res = Builder.CreateCast(cast<CastInst>(Arg)->getOpcode(), Res, + II->getType()); + return replaceInstUsesWith(CI, Res); + } + } + } + LLVM_FALLTHROUGH; + } + case Intrinsic::vector_reduce_smin: + case Intrinsic::vector_reduce_smax: { + if (IID == Intrinsic::vector_reduce_smin || + IID == Intrinsic::vector_reduce_smax) { + // SMin/SMax reduction over the vector with (potentially-extended) + // i1 element type is actually a (potentially-extended) + // logical `and`/`or` reduction over the original non-extended value: + // vector_reduce_s{min,max}(<n x i1>) + // --> + // vector_reduce_{or,and}(<n x i1>) + // and + // vector_reduce_s{min,max}(sext(<n x i1>)) + // --> + // sext(vector_reduce_{or,and}(<n x i1>)) + // and + // vector_reduce_s{min,max}(zext(<n x i1>)) + // --> + // zext(vector_reduce_{and,or}(<n x i1>)) + Value *Arg = II->getArgOperand(0); + Value *Vect; + if (match(Arg, m_ZExtOrSExtOrSelf(m_Value(Vect)))) { + if (auto *FTy = dyn_cast<FixedVectorType>(Vect->getType())) + if (FTy->getElementType() == Builder.getInt1Ty()) { + Instruction::CastOps ExtOpc = Instruction::CastOps::CastOpsEnd; + if (Arg != Vect) + ExtOpc = cast<CastInst>(Arg)->getOpcode(); + Value *Res = ((IID == Intrinsic::vector_reduce_smin) == + (ExtOpc == Instruction::CastOps::ZExt)) + ? Builder.CreateAndReduce(Vect) + : Builder.CreateOrReduce(Vect); + if (Arg != Vect) + Res = Builder.CreateCast(ExtOpc, Res, II->getType()); + return replaceInstUsesWith(CI, Res); + } + } + } + LLVM_FALLTHROUGH; + } + case Intrinsic::vector_reduce_fmax: + case Intrinsic::vector_reduce_fmin: + case Intrinsic::vector_reduce_fadd: + case Intrinsic::vector_reduce_fmul: { + bool CanBeReassociated = (IID != Intrinsic::vector_reduce_fadd && + IID != Intrinsic::vector_reduce_fmul) || + II->hasAllowReassoc(); + const unsigned ArgIdx = (IID == Intrinsic::vector_reduce_fadd || + IID == Intrinsic::vector_reduce_fmul) + ? 
1 + : 0; + Value *Arg = II->getArgOperand(ArgIdx); + Value *V; + ArrayRef<int> Mask; + if (!isa<FixedVectorType>(Arg->getType()) || !CanBeReassociated || + !match(Arg, m_Shuffle(m_Value(V), m_Undef(), m_Mask(Mask))) || + !cast<ShuffleVectorInst>(Arg)->isSingleSource()) + break; + int Sz = Mask.size(); + SmallBitVector UsedIndices(Sz); + for (int Idx : Mask) { + if (Idx == UndefMaskElem || UsedIndices.test(Idx)) + break; + UsedIndices.set(Idx); + } + // Can remove shuffle iff just shuffled elements, no repeats, undefs, or + // other changes. + if (UsedIndices.all()) { + replaceUse(II->getOperandUse(ArgIdx), V); + return nullptr; + } + break; + } + default: { + // Handle target specific intrinsics + Optional<Instruction *> V = targetInstCombineIntrinsic(*II); + if (V) + return V.getValue(); + break; + } + } + + if (Instruction *Shuf = foldShuffledIntrinsicOperands(II, Builder)) + return Shuf; + + // Some intrinsics (like experimental_gc_statepoint) can be used in invoke + // context, so it is handled in visitCallBase and we should trigger it. + return visitCallBase(*II); +} + +// Fence instruction simplification +Instruction *InstCombinerImpl::visitFenceInst(FenceInst &FI) { + auto *NFI = dyn_cast<FenceInst>(FI.getNextNonDebugInstruction()); + // This check is solely here to handle arbitrary target-dependent syncscopes. + // TODO: Can remove if does not matter in practice. + if (NFI && FI.isIdenticalTo(NFI)) + return eraseInstFromFunction(FI); + + // Returns true if FI1 is identical or stronger fence than FI2. + auto isIdenticalOrStrongerFence = [](FenceInst *FI1, FenceInst *FI2) { + auto FI1SyncScope = FI1->getSyncScopeID(); + // Consider same scope, where scope is global or single-thread. 
+ if (FI1SyncScope != FI2->getSyncScopeID() || + (FI1SyncScope != SyncScope::System && + FI1SyncScope != SyncScope::SingleThread)) + return false; + + return isAtLeastOrStrongerThan(FI1->getOrdering(), FI2->getOrdering()); + }; + if (NFI && isIdenticalOrStrongerFence(NFI, &FI)) + return eraseInstFromFunction(FI); + + if (auto *PFI = dyn_cast_or_null<FenceInst>(FI.getPrevNonDebugInstruction())) + if (isIdenticalOrStrongerFence(PFI, &FI)) + return eraseInstFromFunction(FI); + return nullptr; +} + +// InvokeInst simplification +Instruction *InstCombinerImpl::visitInvokeInst(InvokeInst &II) { + return visitCallBase(II); +} + +// CallBrInst simplification +Instruction *InstCombinerImpl::visitCallBrInst(CallBrInst &CBI) { + return visitCallBase(CBI); +} + +/// If this cast does not affect the value passed through the varargs area, we +/// can eliminate the use of the cast. +static bool isSafeToEliminateVarargsCast(const CallBase &Call, + const DataLayout &DL, + const CastInst *const CI, + const int ix) { + if (!CI->isLosslessCast()) + return false; + + // If this is a GC intrinsic, avoid munging types. We need types for + // statepoint reconstruction in SelectionDAG. + // TODO: This is probably something which should be expanded to all + // intrinsics since the entire point of intrinsics is that + // they are understandable by the optimizer. + if (isa<GCStatepointInst>(Call) || isa<GCRelocateInst>(Call) || + isa<GCResultInst>(Call)) + return false; + + // Opaque pointers are compatible with any byval types. + PointerType *SrcTy = cast<PointerType>(CI->getOperand(0)->getType()); + if (SrcTy->isOpaque()) + return true; + + // The size of ByVal or InAlloca arguments is derived from the type, so we + // can't change to a type with a different size. If the size were + // passed explicitly we could avoid this check. 
+ if (!Call.isPassPointeeByValueArgument(ix)) + return true; + + // The transform currently only handles type replacement for byval, not other + // type-carrying attributes. + if (!Call.isByValArgument(ix)) + return false; + + Type *SrcElemTy = SrcTy->getNonOpaquePointerElementType(); + Type *DstElemTy = Call.getParamByValType(ix); + if (!SrcElemTy->isSized() || !DstElemTy->isSized()) + return false; + if (DL.getTypeAllocSize(SrcElemTy) != DL.getTypeAllocSize(DstElemTy)) + return false; + return true; +} + +Instruction *InstCombinerImpl::tryOptimizeCall(CallInst *CI) { + if (!CI->getCalledFunction()) return nullptr; + + // Skip optimizing notail and musttail calls so + // LibCallSimplifier::optimizeCall doesn't have to preserve those invariants. + // LibCallSimplifier::optimizeCall should try to preseve tail calls though. + if (CI->isMustTailCall() || CI->isNoTailCall()) + return nullptr; + + auto InstCombineRAUW = [this](Instruction *From, Value *With) { + replaceInstUsesWith(*From, With); + }; + auto InstCombineErase = [this](Instruction *I) { + eraseInstFromFunction(*I); + }; + LibCallSimplifier Simplifier(DL, &TLI, ORE, BFI, PSI, InstCombineRAUW, + InstCombineErase); + if (Value *With = Simplifier.optimizeCall(CI, Builder)) { + ++NumSimplified; + return CI->use_empty() ? CI : replaceInstUsesWith(*CI, With); + } + + return nullptr; +} + +static IntrinsicInst *findInitTrampolineFromAlloca(Value *TrampMem) { + // Strip off at most one level of pointer casts, looking for an alloca. This + // is good enough in practice and simpler than handling any number of casts. 
+ Value *Underlying = TrampMem->stripPointerCasts(); + if (Underlying != TrampMem && + (!Underlying->hasOneUse() || Underlying->user_back() != TrampMem)) + return nullptr; + if (!isa<AllocaInst>(Underlying)) + return nullptr; + + IntrinsicInst *InitTrampoline = nullptr; + for (User *U : TrampMem->users()) { + IntrinsicInst *II = dyn_cast<IntrinsicInst>(U); + if (!II) + return nullptr; + if (II->getIntrinsicID() == Intrinsic::init_trampoline) { + if (InitTrampoline) + // More than one init_trampoline writes to this value. Give up. + return nullptr; + InitTrampoline = II; + continue; + } + if (II->getIntrinsicID() == Intrinsic::adjust_trampoline) + // Allow any number of calls to adjust.trampoline. + continue; + return nullptr; + } + + // No call to init.trampoline found. + if (!InitTrampoline) + return nullptr; + + // Check that the alloca is being used in the expected way. + if (InitTrampoline->getOperand(0) != TrampMem) + return nullptr; + + return InitTrampoline; +} + +static IntrinsicInst *findInitTrampolineFromBB(IntrinsicInst *AdjustTramp, + Value *TrampMem) { + // Visit all the previous instructions in the basic block, and try to find a + // init.trampoline which has a direct path to the adjust.trampoline. + for (BasicBlock::iterator I = AdjustTramp->getIterator(), + E = AdjustTramp->getParent()->begin(); + I != E;) { + Instruction *Inst = &*--I; + if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(I)) + if (II->getIntrinsicID() == Intrinsic::init_trampoline && + II->getOperand(0) == TrampMem) + return II; + if (Inst->mayWriteToMemory()) + return nullptr; + } + return nullptr; +} + +// Given a call to llvm.adjust.trampoline, find and return the corresponding +// call to llvm.init.trampoline if the call to the trampoline can be optimized +// to a direct call to a function. Otherwise return NULL. 
+static IntrinsicInst *findInitTrampoline(Value *Callee) { + Callee = Callee->stripPointerCasts(); + IntrinsicInst *AdjustTramp = dyn_cast<IntrinsicInst>(Callee); + if (!AdjustTramp || + AdjustTramp->getIntrinsicID() != Intrinsic::adjust_trampoline) + return nullptr; + + Value *TrampMem = AdjustTramp->getOperand(0); + + if (IntrinsicInst *IT = findInitTrampolineFromAlloca(TrampMem)) + return IT; + if (IntrinsicInst *IT = findInitTrampolineFromBB(AdjustTramp, TrampMem)) + return IT; + return nullptr; +} + +bool InstCombinerImpl::annotateAnyAllocSite(CallBase &Call, + const TargetLibraryInfo *TLI) { + // Note: We only handle cases which can't be driven from generic attributes + // here. So, for example, nonnull and noalias (which are common properties + // of some allocation functions) are expected to be handled via annotation + // of the respective allocator declaration with generic attributes. + bool Changed = false; + + if (isAllocationFn(&Call, TLI)) { + uint64_t Size; + ObjectSizeOpts Opts; + if (getObjectSize(&Call, Size, DL, TLI, Opts) && Size > 0) { + // TODO: We really should just emit deref_or_null here and then + // let the generic inference code combine that with nonnull. + if (Call.hasRetAttr(Attribute::NonNull)) { + Changed = !Call.hasRetAttr(Attribute::Dereferenceable); + Call.addRetAttr( + Attribute::getWithDereferenceableBytes(Call.getContext(), Size)); + } else { + Changed = !Call.hasRetAttr(Attribute::DereferenceableOrNull); + Call.addRetAttr(Attribute::getWithDereferenceableOrNullBytes( + Call.getContext(), Size)); + } + } + } + + // Add alignment attribute if alignment is a power of two constant. 
+ Value *Alignment = getAllocAlignment(&Call, TLI); + if (!Alignment) + return Changed; + + ConstantInt *AlignOpC = dyn_cast<ConstantInt>(Alignment); + if (AlignOpC && AlignOpC->getValue().ult(llvm::Value::MaximumAlignment)) { + uint64_t AlignmentVal = AlignOpC->getZExtValue(); + if (llvm::isPowerOf2_64(AlignmentVal)) { + Align ExistingAlign = Call.getRetAlign().valueOrOne(); + Align NewAlign = Align(AlignmentVal); + if (NewAlign > ExistingAlign) { + Call.addRetAttr( + Attribute::getWithAlignment(Call.getContext(), NewAlign)); + Changed = true; + } + } + } + return Changed; +} + +/// Improvements for call, callbr and invoke instructions. +Instruction *InstCombinerImpl::visitCallBase(CallBase &Call) { + bool Changed = annotateAnyAllocSite(Call, &TLI); + + // Mark any parameters that are known to be non-null with the nonnull + // attribute. This is helpful for inlining calls to functions with null + // checks on their arguments. + SmallVector<unsigned, 4> ArgNos; + unsigned ArgNo = 0; + + for (Value *V : Call.args()) { + if (V->getType()->isPointerTy() && + !Call.paramHasAttr(ArgNo, Attribute::NonNull) && + isKnownNonZero(V, DL, 0, &AC, &Call, &DT)) + ArgNos.push_back(ArgNo); + ArgNo++; + } + + assert(ArgNo == Call.arg_size() && "Call arguments not processed correctly."); + + if (!ArgNos.empty()) { + AttributeList AS = Call.getAttributes(); + LLVMContext &Ctx = Call.getContext(); + AS = AS.addParamAttribute(Ctx, ArgNos, + Attribute::get(Ctx, Attribute::NonNull)); + Call.setAttributes(AS); + Changed = true; + } + + // If the callee is a pointer to a function, attempt to move any casts to the + // arguments of the call/callbr/invoke. + Value *Callee = Call.getCalledOperand(); + Function *CalleeF = dyn_cast<Function>(Callee); + if ((!CalleeF || CalleeF->getFunctionType() != Call.getFunctionType()) && + transformConstExprCastCall(Call)) + return nullptr; + + if (CalleeF) { + // Remove the convergent attr on calls when the callee is not convergent. 
// NOTE(review): continuation of InstCombinerImpl::visitCallBase; the head of
// the function (including the opening `if (CalleeF) {`) precedes this span.
    if (Call.isConvergent() && !CalleeF->isConvergent() &&
        !CalleeF->isIntrinsic()) {
      LLVM_DEBUG(dbgs() << "Removing convergent attr from instr " << Call
                        << "\n");
      Call.setNotConvergent();
      return &Call;
    }

    // If the call and callee calling conventions don't match, and neither one
    // of the calling conventions is compatible with C calling convention
    // this call must be unreachable, as the call is undefined.
    if ((CalleeF->getCallingConv() != Call.getCallingConv() &&
         !(CalleeF->getCallingConv() == llvm::CallingConv::C &&
           TargetLibraryInfoImpl::isCallingConvCCompatible(&Call)) &&
         !(Call.getCallingConv() == llvm::CallingConv::C &&
           TargetLibraryInfoImpl::isCallingConvCCompatible(CalleeF))) &&
        // Only do this for calls to a function with a body.  A prototype may
        // not actually end up matching the implementation's calling conv for a
        // variety of reasons (e.g. it may be written in assembly).
        !CalleeF->isDeclaration()) {
      Instruction *OldCall = &Call;
      CreateNonTerminatorUnreachable(OldCall);
      // If OldCall does not return void then replaceInstUsesWith poison.
      // This allows ValueHandlers and custom metadata to adjust itself.
      if (!OldCall->getType()->isVoidTy())
        replaceInstUsesWith(*OldCall, PoisonValue::get(OldCall->getType()));
      if (isa<CallInst>(OldCall))
        return eraseInstFromFunction(*OldCall);

      // We cannot remove an invoke or a callbr, because it would change the
      // CFG, just change the callee to a null pointer.
      cast<CallBase>(OldCall)->setCalledFunction(
          CalleeF->getFunctionType(),
          Constant::getNullValue(CalleeF->getType()));
      return nullptr;
    }
  }

  // Calling a null function pointer is undefined if a null address isn't
  // dereferenceable.
  if ((isa<ConstantPointerNull>(Callee) &&
       !NullPointerIsDefined(Call.getFunction())) ||
      isa<UndefValue>(Callee)) {
    // If Call does not return void then replaceInstUsesWith poison.
    // This allows ValueHandlers and custom metadata to adjust itself.
    if (!Call.getType()->isVoidTy())
      replaceInstUsesWith(Call, PoisonValue::get(Call.getType()));

    if (Call.isTerminator()) {
      // Can't remove an invoke or callbr because we cannot change the CFG.
      return nullptr;
    }

    // This instruction is not reachable, just remove it.
    CreateNonTerminatorUnreachable(&Call);
    return eraseInstFromFunction(Call);
  }

  // A call through an adjust.trampoline result with a matching
  // init.trampoline is handed to transformCallThroughTrampoline.
  if (IntrinsicInst *II = findInitTrampoline(Callee))
    return transformCallThroughTrampoline(Call, *II);

  // TODO: Drop this transform once opaque pointer transition is done.
  FunctionType *FTy = Call.getFunctionType();
  if (FTy->isVarArg()) {
    // ix tracks the argument index of the iterator I, starting at the first
    // variadic argument.
    int ix = FTy->getNumParams();
    // See if we can optimize any arguments passed through the varargs area of
    // the call.
    for (auto I = Call.arg_begin() + FTy->getNumParams(), E = Call.arg_end();
         I != E; ++I, ++ix) {
      CastInst *CI = dyn_cast<CastInst>(*I);
      if (CI && isSafeToEliminateVarargsCast(Call, DL, CI, ix)) {
        replaceUse(*I, CI->getOperand(0));

        // Update the byval type to match the pointer type.
        // Not necessary for opaque pointers.
        PointerType *NewTy = cast<PointerType>(CI->getOperand(0)->getType());
        if (!NewTy->isOpaque() && Call.isByValArgument(ix)) {
          Call.removeParamAttr(ix, Attribute::ByVal);
          Call.addParamAttr(ix, Attribute::getWithByValType(
                                    Call.getContext(),
                                    NewTy->getNonOpaquePointerElementType()));
        }
        Changed = true;
      }
    }
  }

  if (isa<InlineAsm>(Callee) && !Call.doesNotThrow()) {
    InlineAsm *IA = cast<InlineAsm>(Callee);
    if (!IA->canThrow()) {
      // Normal inline asm calls cannot throw - mark them
      // 'nounwind'.
      Call.setDoesNotThrow();
      Changed = true;
    }
  }

  // Try to optimize the call if possible, we require DataLayout for most of
  // this.  None of these calls are seen as possibly dead so go ahead and
  // delete the instruction now.
  if (CallInst *CI = dyn_cast<CallInst>(&Call)) {
    Instruction *I = tryOptimizeCall(CI);
    // If we changed something return the result, etc. Otherwise let
    // the fallthrough check.
    if (I) return eraseInstFromFunction(*I);
  }

  // If the call has a returned-argument operand whose type can be losslessly
  // bitcast to the call's type, users of the call can use (a cast of) that
  // argument instead.
  if (!Call.use_empty() && !Call.isMustTailCall())
    if (Value *ReturnedArg = Call.getReturnedArgOperand()) {
      Type *CallTy = Call.getType();
      Type *RetArgTy = ReturnedArg->getType();
      if (RetArgTy->canLosslesslyBitCastTo(CallTy))
        return replaceInstUsesWith(
            Call, Builder.CreateBitOrPointerCast(ReturnedArg, CallTy));
    }

  if (isAllocationFn(&Call, &TLI) &&
      isAllocRemovable(&cast<CallBase>(Call), &TLI))
    return visitAllocSite(Call);

  // Handle intrinsics which can be used in both call and invoke context.
  switch (Call.getIntrinsicID()) {
  case Intrinsic::experimental_gc_statepoint: {
    GCStatepointInst &GCSP = *cast<GCStatepointInst>(&Call);
    // Base and derived pointers still referenced by a surviving gc.relocate.
    SmallPtrSet<Value *, 32> LiveGcValues;
    for (const GCRelocateInst *Reloc : GCSP.getGCRelocates()) {
      GCRelocateInst &GCR = *const_cast<GCRelocateInst *>(Reloc);

      // Remove the relocation if unused.
      if (GCR.use_empty()) {
        eraseInstFromFunction(GCR);
        continue;
      }

      Value *DerivedPtr = GCR.getDerivedPtr();
      Value *BasePtr = GCR.getBasePtr();

      // Undef is undef, even after relocation.
      if (isa<UndefValue>(DerivedPtr) || isa<UndefValue>(BasePtr)) {
        replaceInstUsesWith(GCR, UndefValue::get(GCR.getType()));
        eraseInstFromFunction(GCR);
        continue;
      }

      if (auto *PT = dyn_cast<PointerType>(GCR.getType())) {
        // The relocation of null will be null for most any collector.
        // TODO: provide a hook for this in GCStrategy.  There might be some
        // weird collector this property does not hold for.
        if (isa<ConstantPointerNull>(DerivedPtr)) {
          // Use null-pointer of gc_relocate's type to replace it.
          replaceInstUsesWith(GCR, ConstantPointerNull::get(PT));
          eraseInstFromFunction(GCR);
          continue;
        }

        // isKnownNonNull -> nonnull attribute
        if (!GCR.hasRetAttr(Attribute::NonNull) &&
            isKnownNonZero(DerivedPtr, DL, 0, &AC, &Call, &DT)) {
          GCR.addRetAttr(Attribute::NonNull);
          // We discovered new fact, re-check users.
          Worklist.pushUsersToWorkList(GCR);
        }
      }

      // If we have two copies of the same pointer in the statepoint argument
      // list, canonicalize to one.  This may let us common gc.relocates.
      if (GCR.getBasePtr() == GCR.getDerivedPtr() &&
          GCR.getBasePtrIndex() != GCR.getDerivedPtrIndex()) {
        auto *OpIntTy = GCR.getOperand(2)->getType();
        GCR.setOperand(2, ConstantInt::get(OpIntTy, GCR.getBasePtrIndex()));
      }

      // TODO: bitcast(relocate(p)) -> relocate(bitcast(p))
      // Canonicalize on the type from the uses to the defs

      // TODO: relocate((gep p, C, C2, ...)) -> gep(relocate(p), C, C2, ...)
      LiveGcValues.insert(BasePtr);
      LiveGcValues.insert(DerivedPtr);
    }
    Optional<OperandBundleUse> Bundle =
        GCSP.getOperandBundle(LLVMContext::OB_gc_live);
    unsigned NumOfGCLives = LiveGcValues.size();
    if (!Bundle || NumOfGCLives == Bundle->Inputs.size())
      break;
    // We can reduce the size of gc live bundle.
    // Val2Idx maps each bundle input to its index in the new bundle; inputs
    // that are no longer live get the sentinel index NumOfGCLives.
    DenseMap<Value *, unsigned> Val2Idx;
    std::vector<Value *> NewLiveGc;
    for (unsigned I = 0, E = Bundle->Inputs.size(); I < E; ++I) {
      Value *V = Bundle->Inputs[I];
      if (Val2Idx.count(V))
        continue;
      if (LiveGcValues.count(V)) {
        Val2Idx[V] = NewLiveGc.size();
        NewLiveGc.push_back(V);
      } else
        Val2Idx[V] = NumOfGCLives;
    }
    // Update all gc.relocates
    for (const GCRelocateInst *Reloc : GCSP.getGCRelocates()) {
      GCRelocateInst &GCR = *const_cast<GCRelocateInst *>(Reloc);
      Value *BasePtr = GCR.getBasePtr();
      assert(Val2Idx.count(BasePtr) && Val2Idx[BasePtr] != NumOfGCLives &&
             "Missed live gc for base pointer");
      auto *OpIntTy1 = GCR.getOperand(1)->getType();
      GCR.setOperand(1, ConstantInt::get(OpIntTy1, Val2Idx[BasePtr]));
      Value *DerivedPtr = GCR.getDerivedPtr();
      assert(Val2Idx.count(DerivedPtr) && Val2Idx[DerivedPtr] != NumOfGCLives &&
             "Missed live gc for derived pointer");
      auto *OpIntTy2 = GCR.getOperand(2)->getType();
      GCR.setOperand(2, ConstantInt::get(OpIntTy2, Val2Idx[DerivedPtr]));
    }
    // Create new statepoint instruction.
    OperandBundleDef NewBundle("gc-live", NewLiveGc);
    return CallBase::Create(&Call, NewBundle);
  }
  default: { break; }
  }

  // Report the call as modified if only attributes or operands were updated.
  return Changed ? &Call : nullptr;
}

/// If the callee is a constexpr cast of a function, attempt to move the cast to
/// the arguments of the call/callbr/invoke.
///
/// \returns true if the call was replaced by a direct call (the original call
/// is erased in that case), false if nothing was changed.
bool InstCombinerImpl::transformConstExprCastCall(CallBase &Call) {
  auto *Callee =
      dyn_cast<Function>(Call.getCalledOperand()->stripPointerCasts());
  if (!Callee)
    return false;

  // If this is a call to a thunk function, don't remove the cast. Thunks are
  // used to transparently forward all incoming parameters and outgoing return
  // values, so it's important to leave the cast in place.
  if (Callee->hasFnAttribute("thunk"))
    return false;

  // If this is a musttail call, the callee's prototype must match the caller's
  // prototype with the exception of pointee types.
// NOTE(review): continuation of InstCombinerImpl::transformConstExprCastCall;
// the first comment lines below finish the musttail remark begun just above.
  // The code below doesn't
  // implement that, so we can't do this transform.
  // TODO: Do the transform if it only requires adding pointer casts.
  if (Call.isMustTailCall())
    return false;

  Instruction *Caller = &Call;
  const AttributeList &CallerPAL = Call.getAttributes();

  // Okay, this is a cast from a function to a different type.  Unless doing so
  // would cause a type conversion of one of our arguments, change this call to
  // be a direct call with arguments casted to the appropriate types.
  FunctionType *FT = Callee->getFunctionType();
  Type *OldRetTy = Caller->getType();
  Type *NewRetTy = FT->getReturnType();

  // Check to see if we are changing the return type...
  if (OldRetTy != NewRetTy) {

    if (NewRetTy->isStructTy())
      return false; // TODO: Handle multiple return values.

    if (!CastInst::isBitOrNoopPointerCastable(NewRetTy, OldRetTy, DL)) {
      if (Callee->isDeclaration())
        return false;   // Cannot transform this return value.

      if (!Caller->use_empty() &&
          // void -> non-void is handled specially
          !NewRetTy->isVoidTy())
        return false;   // Cannot transform this return value.
    }

    if (!CallerPAL.isEmpty() && !Caller->use_empty()) {
      AttrBuilder RAttrs(FT->getContext(), CallerPAL.getRetAttrs());
      if (RAttrs.overlaps(AttributeFuncs::typeIncompatible(NewRetTy)))
        return false;   // Attribute not compatible with transformed value.
    }

    // If the callbase is an invoke/callbr instruction, and the return value is
    // used by a PHI node in a successor, we cannot change the return type of
    // the call because there is no place to put the cast instruction (without
    // breaking the critical edge).  Bail out in this case.
    if (!Caller->use_empty()) {
      if (InvokeInst *II = dyn_cast<InvokeInst>(Caller))
        for (User *U : II->users())
          if (PHINode *PN = dyn_cast<PHINode>(U))
            if (PN->getParent() == II->getNormalDest() ||
                PN->getParent() == II->getUnwindDest())
              return false;
      // FIXME: Be conservative for callbr to avoid a quadratic search.
      if (isa<CallBrInst>(Caller))
        return false;
    }
  }

  // NumCommonArgs is the number of leading arguments that exist both at the
  // call site and in the callee's prototype.
  unsigned NumActualArgs = Call.arg_size();
  unsigned NumCommonArgs = std::min(FT->getNumParams(), NumActualArgs);

  // Prevent us turning:
  // declare void @takes_i32_inalloca(i32* inalloca)
  //  call void bitcast (void (i32*)* @takes_i32_inalloca to void (i32)*)(i32 0)
  //
  // into:
  //  call void @takes_i32_inalloca(i32* null)
  //
  //  Similarly, avoid folding away bitcasts of byval calls.
  // NOTE(review): the check below only tests inalloca/preallocated on the
  // callee; byval is handled per argument in the loop further down.
  if (Callee->getAttributes().hasAttrSomewhere(Attribute::InAlloca) ||
      Callee->getAttributes().hasAttrSomewhere(Attribute::Preallocated))
    return false;

  // First pass over the common arguments: only checks, no IR is modified.
  auto AI = Call.arg_begin();
  for (unsigned i = 0, e = NumCommonArgs; i != e; ++i, ++AI) {
    Type *ParamTy = FT->getParamType(i);
    Type *ActTy = (*AI)->getType();

    if (!CastInst::isBitOrNoopPointerCastable(ActTy, ParamTy, DL))
      return false;   // Cannot transform this parameter value.

    // Check if there are any incompatible attributes we cannot drop safely.
    if (AttrBuilder(FT->getContext(), CallerPAL.getParamAttrs(i))
            .overlaps(AttributeFuncs::typeIncompatible(
                ParamTy, AttributeFuncs::ASK_UNSAFE_TO_DROP)))
      return false;   // Attribute not compatible with transformed value.

    if (Call.isInAllocaArgument(i) ||
        CallerPAL.hasParamAttr(i, Attribute::Preallocated))
      return false; // Cannot transform to and from inalloca/preallocated.

    if (CallerPAL.hasParamAttr(i, Attribute::SwiftError))
      return false;

    // If the parameter is passed as a byval argument, then we have to have a
    // sized type and the sized type has to have the same size as the old type.
    if (ParamTy != ActTy && CallerPAL.hasParamAttr(i, Attribute::ByVal)) {
      PointerType *ParamPTy = dyn_cast<PointerType>(ParamTy);
      if (!ParamPTy)
        return false;

      if (!ParamPTy->isOpaque()) {
        Type *ParamElTy = ParamPTy->getNonOpaquePointerElementType();
        if (!ParamElTy->isSized())
          return false;

        Type *CurElTy = Call.getParamByValType(i);
        if (DL.getTypeAllocSize(CurElTy) != DL.getTypeAllocSize(ParamElTy))
          return false;
      }
    }
  }

  if (Callee->isDeclaration()) {
    // Do not delete arguments unless we have a function body.
    if (FT->getNumParams() < NumActualArgs && !FT->isVarArg())
      return false;

    // If the callee is just a declaration, don't change the varargsness of the
    // call.  We don't want to introduce a varargs call where one doesn't
    // already exist.
    if (FT->isVarArg() != Call.getFunctionType()->isVarArg())
      return false;

    // If both the callee and the cast type are varargs, we still have to make
    // sure the number of fixed parameters are the same or we have the same
    // ABI issues as if we introduce a varargs call.
    if (FT->isVarArg() && Call.getFunctionType()->isVarArg() &&
        FT->getNumParams() != Call.getFunctionType()->getNumParams())
      return false;
  }

  if (FT->getNumParams() < NumActualArgs && FT->isVarArg() &&
      !CallerPAL.isEmpty()) {
    // In this case we have more arguments than the new function type, but we
    // won't be dropping them.  Check that these extra arguments have attributes
    // that are compatible with being a vararg call argument.
    unsigned SRetIdx;
    if (CallerPAL.hasAttrSomewhere(Attribute::StructRet, &SRetIdx) &&
        SRetIdx - AttributeList::FirstArgIndex >= FT->getNumParams())
      return false;
  }

  // Okay, we decided that this is a safe thing to do: go ahead and start
  // inserting cast instructions as necessary.
  SmallVector<Value *, 8> Args;
  SmallVector<AttributeSet, 8> ArgAttrs;
  Args.reserve(NumActualArgs);
  ArgAttrs.reserve(NumActualArgs);

  // Get any return attributes.
  AttrBuilder RAttrs(FT->getContext(), CallerPAL.getRetAttrs());

  // If the return value is not being used, the type may not be compatible
  // with the existing attributes.  Wipe out any problematic attributes.
  RAttrs.remove(AttributeFuncs::typeIncompatible(NewRetTy));

  // Second pass over the common arguments: build the new argument list,
  // inserting casts where the types differ.
  LLVMContext &Ctx = Call.getContext();
  AI = Call.arg_begin();
  for (unsigned i = 0; i != NumCommonArgs; ++i, ++AI) {
    Type *ParamTy = FT->getParamType(i);

    Value *NewArg = *AI;
    if ((*AI)->getType() != ParamTy)
      NewArg = Builder.CreateBitOrPointerCast(*AI, ParamTy);
    Args.push_back(NewArg);

    // Add any parameter attributes except the ones incompatible with the new
    // type. Note that we made sure all incompatible ones are safe to drop.
    AttributeMask IncompatibleAttrs = AttributeFuncs::typeIncompatible(
        ParamTy, AttributeFuncs::ASK_SAFE_TO_DROP);
    if (CallerPAL.hasParamAttr(i, Attribute::ByVal) &&
        !ParamTy->isOpaquePointerTy()) {
      AttrBuilder AB(Ctx, CallerPAL.getParamAttrs(i).removeAttributes(
                              Ctx, IncompatibleAttrs));
      AB.addByValAttr(ParamTy->getNonOpaquePointerElementType());
      ArgAttrs.push_back(AttributeSet::get(Ctx, AB));
    } else {
      ArgAttrs.push_back(
          CallerPAL.getParamAttrs(i).removeAttributes(Ctx, IncompatibleAttrs));
    }
  }

  // If the function takes more arguments than the call was taking, add them
  // now.
  for (unsigned i = NumCommonArgs; i != FT->getNumParams(); ++i) {
    Args.push_back(Constant::getNullValue(FT->getParamType(i)));
    ArgAttrs.push_back(AttributeSet());
  }

  // If we are removing arguments to the function, emit an obnoxious warning.
  if (FT->getNumParams() < NumActualArgs) {
    // TODO: if (!FT->isVarArg()) this call may be unreachable. PR14722
    if (FT->isVarArg()) {
      // Add all of the arguments in their promoted form to the arg list.
      for (unsigned i = FT->getNumParams(); i != NumActualArgs; ++i, ++AI) {
        Type *PTy = getPromotedType((*AI)->getType());
        Value *NewArg = *AI;
        if (PTy != (*AI)->getType()) {
          // Must promote to pass through va_arg area!
          Instruction::CastOps opcode =
            CastInst::getCastOpcode(*AI, false, PTy, false);
          NewArg = Builder.CreateCast(opcode, *AI, PTy);
        }
        Args.push_back(NewArg);

        // Add any parameter attributes.
        ArgAttrs.push_back(CallerPAL.getParamAttrs(i));
      }
    }
  }

  AttributeSet FnAttrs = CallerPAL.getFnAttrs();

  if (NewRetTy->isVoidTy())
    Caller->setName("");   // Void type should not have a name.

  assert((ArgAttrs.size() == FT->getNumParams() || FT->isVarArg()) &&
         "missing argument attributes");
  AttributeList NewCallerPAL = AttributeList::get(
      Ctx, FnAttrs, AttributeSet::get(Ctx, RAttrs), ArgAttrs);

  SmallVector<OperandBundleDef, 1> OpBundles;
  Call.getOperandBundlesAsDefs(OpBundles);

  // Build the replacement with the same instruction kind as the original
  // (invoke, callbr or plain call).
  CallBase *NewCall;
  if (InvokeInst *II = dyn_cast<InvokeInst>(Caller)) {
    NewCall = Builder.CreateInvoke(Callee, II->getNormalDest(),
                                   II->getUnwindDest(), Args, OpBundles);
  } else if (CallBrInst *CBI = dyn_cast<CallBrInst>(Caller)) {
    NewCall = Builder.CreateCallBr(Callee, CBI->getDefaultDest(),
                                   CBI->getIndirectDests(), Args, OpBundles);
  } else {
    NewCall = Builder.CreateCall(Callee, Args, OpBundles);
    cast<CallInst>(NewCall)->setTailCallKind(
        cast<CallInst>(Caller)->getTailCallKind());
  }
  NewCall->takeName(Caller);
  NewCall->setCallingConv(Call.getCallingConv());
  NewCall->setAttributes(NewCallerPAL);

  // Preserve prof metadata if any.
  NewCall->copyMetadata(*Caller, {LLVMContext::MD_prof});

  // Insert a cast of the return type as necessary.
  // NC is the last new instruction created; NV is the value that replaces the
  // uses of the original call.
  Instruction *NC = NewCall;
  Value *NV = NC;
  if (OldRetTy != NV->getType() && !Caller->use_empty()) {
    if (!NV->getType()->isVoidTy()) {
      NV = NC = CastInst::CreateBitOrPointerCast(NC, OldRetTy);
      NC->setDebugLoc(Caller->getDebugLoc());

      // If this is an invoke/callbr instruction, we should insert it after the
      // first non-phi instruction in the normal successor block.
      if (InvokeInst *II = dyn_cast<InvokeInst>(Caller)) {
        BasicBlock::iterator I = II->getNormalDest()->getFirstInsertionPt();
        InsertNewInstBefore(NC, *I);
      } else if (CallBrInst *CBI = dyn_cast<CallBrInst>(Caller)) {
        BasicBlock::iterator I = CBI->getDefaultDest()->getFirstInsertionPt();
        InsertNewInstBefore(NC, *I);
      } else {
        // Otherwise, it's a call, just insert cast right after the call.
        InsertNewInstBefore(NC, *Caller);
      }
      Worklist.pushUsersToWorkList(*Caller);
    } else {
      NV = UndefValue::get(Caller->getType());
    }
  }

  if (!Caller->use_empty())
    replaceInstUsesWith(*Caller, NV);
  else if (Caller->hasValueHandle()) {
    if (OldRetTy == NV->getType())
      ValueHandleBase::ValueIsRAUWd(Caller, NV);
    else
      // We cannot call ValueIsRAUWd with a different type, and the
      // actual tracked value will disappear.
      ValueHandleBase::ValueIsDeleted(Caller);
  }

  eraseInstFromFunction(*Caller);
  return true;
}

/// Turn a call to a function created by init_trampoline / adjust_trampoline
/// intrinsic pair into a direct call to the underlying function.
// NOTE(review): only the first lines of this function are visible in this
// span; they are reproduced unchanged.
Instruction *
InstCombinerImpl::transformCallThroughTrampoline(CallBase &Call,
                                                 IntrinsicInst &Tramp) {
  Value *Callee = Call.getCalledOperand();
  Type *CalleeTy = Callee->getType();
  FunctionType *FTy = Call.getFunctionType();
  AttributeList Attrs = Call.getAttributes();

  // If the call already has the 'nest' attribute somewhere then give up -
  // otherwise 'nest' would occur twice after splicing in the chain.
+ if (Attrs.hasAttrSomewhere(Attribute::Nest)) + return nullptr; + + Function *NestF = cast<Function>(Tramp.getArgOperand(1)->stripPointerCasts()); + FunctionType *NestFTy = NestF->getFunctionType(); + + AttributeList NestAttrs = NestF->getAttributes(); + if (!NestAttrs.isEmpty()) { + unsigned NestArgNo = 0; + Type *NestTy = nullptr; + AttributeSet NestAttr; + + // Look for a parameter marked with the 'nest' attribute. + for (FunctionType::param_iterator I = NestFTy->param_begin(), + E = NestFTy->param_end(); + I != E; ++NestArgNo, ++I) { + AttributeSet AS = NestAttrs.getParamAttrs(NestArgNo); + if (AS.hasAttribute(Attribute::Nest)) { + // Record the parameter type and any other attributes. + NestTy = *I; + NestAttr = AS; + break; + } + } + + if (NestTy) { + std::vector<Value*> NewArgs; + std::vector<AttributeSet> NewArgAttrs; + NewArgs.reserve(Call.arg_size() + 1); + NewArgAttrs.reserve(Call.arg_size()); + + // Insert the nest argument into the call argument list, which may + // mean appending it. Likewise for attributes. + + { + unsigned ArgNo = 0; + auto I = Call.arg_begin(), E = Call.arg_end(); + do { + if (ArgNo == NestArgNo) { + // Add the chain argument and attributes. + Value *NestVal = Tramp.getArgOperand(2); + if (NestVal->getType() != NestTy) + NestVal = Builder.CreateBitCast(NestVal, NestTy, "nest"); + NewArgs.push_back(NestVal); + NewArgAttrs.push_back(NestAttr); + } + + if (I == E) + break; + + // Add the original argument and attributes. + NewArgs.push_back(*I); + NewArgAttrs.push_back(Attrs.getParamAttrs(ArgNo)); + + ++ArgNo; + ++I; + } while (true); + } + + // The trampoline may have been bitcast to a bogus type (FTy). + // Handle this by synthesizing a new function type, equal to FTy + // with the chain parameter inserted. + + std::vector<Type*> NewTypes; + NewTypes.reserve(FTy->getNumParams()+1); + + // Insert the chain's type into the list of parameter types, which may + // mean appending it. 
+ { + unsigned ArgNo = 0; + FunctionType::param_iterator I = FTy->param_begin(), + E = FTy->param_end(); + + do { + if (ArgNo == NestArgNo) + // Add the chain's type. + NewTypes.push_back(NestTy); + + if (I == E) + break; + + // Add the original type. + NewTypes.push_back(*I); + + ++ArgNo; + ++I; + } while (true); + } + + // Replace the trampoline call with a direct call. Let the generic + // code sort out any function type mismatches. + FunctionType *NewFTy = FunctionType::get(FTy->getReturnType(), NewTypes, + FTy->isVarArg()); + Constant *NewCallee = + NestF->getType() == PointerType::getUnqual(NewFTy) ? + NestF : ConstantExpr::getBitCast(NestF, + PointerType::getUnqual(NewFTy)); + AttributeList NewPAL = + AttributeList::get(FTy->getContext(), Attrs.getFnAttrs(), + Attrs.getRetAttrs(), NewArgAttrs); + + SmallVector<OperandBundleDef, 1> OpBundles; + Call.getOperandBundlesAsDefs(OpBundles); + + Instruction *NewCaller; + if (InvokeInst *II = dyn_cast<InvokeInst>(&Call)) { + NewCaller = InvokeInst::Create(NewFTy, NewCallee, + II->getNormalDest(), II->getUnwindDest(), + NewArgs, OpBundles); + cast<InvokeInst>(NewCaller)->setCallingConv(II->getCallingConv()); + cast<InvokeInst>(NewCaller)->setAttributes(NewPAL); + } else if (CallBrInst *CBI = dyn_cast<CallBrInst>(&Call)) { + NewCaller = + CallBrInst::Create(NewFTy, NewCallee, CBI->getDefaultDest(), + CBI->getIndirectDests(), NewArgs, OpBundles); + cast<CallBrInst>(NewCaller)->setCallingConv(CBI->getCallingConv()); + cast<CallBrInst>(NewCaller)->setAttributes(NewPAL); + } else { + NewCaller = CallInst::Create(NewFTy, NewCallee, NewArgs, OpBundles); + cast<CallInst>(NewCaller)->setTailCallKind( + cast<CallInst>(Call).getTailCallKind()); + cast<CallInst>(NewCaller)->setCallingConv( + cast<CallInst>(Call).getCallingConv()); + cast<CallInst>(NewCaller)->setAttributes(NewPAL); + } + NewCaller->setDebugLoc(Call.getDebugLoc()); + + return NewCaller; + } + } + + // Replace the trampoline call with a direct call. 
Since there is no 'nest' + // parameter, there is no need to adjust the argument list. Let the generic + // code sort out any function type mismatches. + Constant *NewCallee = ConstantExpr::getBitCast(NestF, CalleeTy); + Call.setCalledFunction(FTy, NewCallee); + return &Call; +} diff --git a/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp b/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp new file mode 100644 index 000000000000..e9e779b8619b --- /dev/null +++ b/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp @@ -0,0 +1,2921 @@ +//===- InstCombineCasts.cpp -----------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the visit functions for cast operations. +// +//===----------------------------------------------------------------------===// + +#include "InstCombineInternal.h" +#include "llvm/ADT/SetVector.h" +#include "llvm/Analysis/ConstantFolding.h" +#include "llvm/IR/DataLayout.h" +#include "llvm/IR/PatternMatch.h" +#include "llvm/Support/KnownBits.h" +#include "llvm/Transforms/InstCombine/InstCombiner.h" +using namespace llvm; +using namespace PatternMatch; + +#define DEBUG_TYPE "instcombine" + +/// Analyze 'Val', seeing if it is a simple linear expression. +/// If so, decompose it, returning some value X, such that Val is +/// X*Scale+Offset. 
+/// +static Value *decomposeSimpleLinearExpr(Value *Val, unsigned &Scale, + uint64_t &Offset) { + if (ConstantInt *CI = dyn_cast<ConstantInt>(Val)) { + Offset = CI->getZExtValue(); + Scale = 0; + return ConstantInt::get(Val->getType(), 0); + } + + if (BinaryOperator *I = dyn_cast<BinaryOperator>(Val)) { + // Cannot look past anything that might overflow. + // We specifically require nuw because we store the Scale in an unsigned + // and perform an unsigned divide on it. + OverflowingBinaryOperator *OBI = dyn_cast<OverflowingBinaryOperator>(Val); + if (OBI && !OBI->hasNoUnsignedWrap()) { + Scale = 1; + Offset = 0; + return Val; + } + + if (ConstantInt *RHS = dyn_cast<ConstantInt>(I->getOperand(1))) { + if (I->getOpcode() == Instruction::Shl) { + // This is a value scaled by '1 << the shift amt'. + Scale = UINT64_C(1) << RHS->getZExtValue(); + Offset = 0; + return I->getOperand(0); + } + + if (I->getOpcode() == Instruction::Mul) { + // This value is scaled by 'RHS'. + Scale = RHS->getZExtValue(); + Offset = 0; + return I->getOperand(0); + } + + if (I->getOpcode() == Instruction::Add) { + // We have X+C. Check to see if we really have (X*C2)+C1, + // where C1 is divisible by C2. + unsigned SubScale; + Value *SubVal = + decomposeSimpleLinearExpr(I->getOperand(0), SubScale, Offset); + Offset += RHS->getZExtValue(); + Scale = SubScale; + return SubVal; + } + } + } + + // Otherwise, we can't look past this. + Scale = 1; + Offset = 0; + return Val; +} + +/// If we find a cast of an allocation instruction, try to eliminate the cast by +/// moving the type information into the alloc. +Instruction *InstCombinerImpl::PromoteCastOfAllocation(BitCastInst &CI, + AllocaInst &AI) { + PointerType *PTy = cast<PointerType>(CI.getType()); + // Opaque pointers don't have an element type we could replace with. 
+ if (PTy->isOpaque()) + return nullptr; + + IRBuilderBase::InsertPointGuard Guard(Builder); + Builder.SetInsertPoint(&AI); + + // Get the type really allocated and the type casted to. + Type *AllocElTy = AI.getAllocatedType(); + Type *CastElTy = PTy->getNonOpaquePointerElementType(); + if (!AllocElTy->isSized() || !CastElTy->isSized()) return nullptr; + + // This optimisation does not work for cases where the cast type + // is scalable and the allocated type is not. This because we need to + // know how many times the casted type fits into the allocated type. + // For the opposite case where the allocated type is scalable and the + // cast type is not this leads to poor code quality due to the + // introduction of 'vscale' into the calculations. It seems better to + // bail out for this case too until we've done a proper cost-benefit + // analysis. + bool AllocIsScalable = isa<ScalableVectorType>(AllocElTy); + bool CastIsScalable = isa<ScalableVectorType>(CastElTy); + if (AllocIsScalable != CastIsScalable) return nullptr; + + Align AllocElTyAlign = DL.getABITypeAlign(AllocElTy); + Align CastElTyAlign = DL.getABITypeAlign(CastElTy); + if (CastElTyAlign < AllocElTyAlign) return nullptr; + + // If the allocation has multiple uses, only promote it if we are strictly + // increasing the alignment of the resultant allocation. If we keep it the + // same, we open the door to infinite loops of various kinds. + if (!AI.hasOneUse() && CastElTyAlign == AllocElTyAlign) return nullptr; + + // The alloc and cast types should be either both fixed or both scalable. + uint64_t AllocElTySize = DL.getTypeAllocSize(AllocElTy).getKnownMinSize(); + uint64_t CastElTySize = DL.getTypeAllocSize(CastElTy).getKnownMinSize(); + if (CastElTySize == 0 || AllocElTySize == 0) return nullptr; + + // If the allocation has multiple uses, only promote it if we're not + // shrinking the amount of memory being allocated. 
+ uint64_t AllocElTyStoreSize = DL.getTypeStoreSize(AllocElTy).getKnownMinSize(); + uint64_t CastElTyStoreSize = DL.getTypeStoreSize(CastElTy).getKnownMinSize(); + if (!AI.hasOneUse() && CastElTyStoreSize < AllocElTyStoreSize) return nullptr; + + // See if we can satisfy the modulus by pulling a scale out of the array + // size argument. + unsigned ArraySizeScale; + uint64_t ArrayOffset; + Value *NumElements = // See if the array size is a decomposable linear expr. + decomposeSimpleLinearExpr(AI.getOperand(0), ArraySizeScale, ArrayOffset); + + // If we can now satisfy the modulus, by using a non-1 scale, we really can + // do the xform. + if ((AllocElTySize*ArraySizeScale) % CastElTySize != 0 || + (AllocElTySize*ArrayOffset ) % CastElTySize != 0) return nullptr; + + // We don't currently support arrays of scalable types. + assert(!AllocIsScalable || (ArrayOffset == 1 && ArraySizeScale == 0)); + + unsigned Scale = (AllocElTySize*ArraySizeScale)/CastElTySize; + Value *Amt = nullptr; + if (Scale == 1) { + Amt = NumElements; + } else { + Amt = ConstantInt::get(AI.getArraySize()->getType(), Scale); + // Insert before the alloca, not before the cast. + Amt = Builder.CreateMul(Amt, NumElements); + } + + if (uint64_t Offset = (AllocElTySize*ArrayOffset)/CastElTySize) { + Value *Off = ConstantInt::get(AI.getArraySize()->getType(), + Offset, true); + Amt = Builder.CreateAdd(Amt, Off); + } + + AllocaInst *New = Builder.CreateAlloca(CastElTy, AI.getAddressSpace(), Amt); + New->setAlignment(AI.getAlign()); + New->takeName(&AI); + New->setUsedWithInAlloca(AI.isUsedWithInAlloca()); + + // If the allocation has multiple real uses, insert a cast and change all + // things that used it to use the new cast. This will also hack on CI, but it + // will die soon. + if (!AI.hasOneUse()) { + // New is the allocation instruction, pointer typed. AI is the original + // allocation instruction, also pointer typed. Thus, cast to use is BitCast. 
+ Value *NewCast = Builder.CreateBitCast(New, AI.getType(), "tmpcast"); + replaceInstUsesWith(AI, NewCast); + eraseInstFromFunction(AI); + } + return replaceInstUsesWith(CI, New); +} + +/// Given an expression that CanEvaluateTruncated or CanEvaluateSExtd returns +/// true for, actually insert the code to evaluate the expression. +Value *InstCombinerImpl::EvaluateInDifferentType(Value *V, Type *Ty, + bool isSigned) { + if (Constant *C = dyn_cast<Constant>(V)) { + C = ConstantExpr::getIntegerCast(C, Ty, isSigned /*Sext or ZExt*/); + // If we got a constantexpr back, try to simplify it with DL info. + return ConstantFoldConstant(C, DL, &TLI); + } + + // Otherwise, it must be an instruction. + Instruction *I = cast<Instruction>(V); + Instruction *Res = nullptr; + unsigned Opc = I->getOpcode(); + switch (Opc) { + case Instruction::Add: + case Instruction::Sub: + case Instruction::Mul: + case Instruction::And: + case Instruction::Or: + case Instruction::Xor: + case Instruction::AShr: + case Instruction::LShr: + case Instruction::Shl: + case Instruction::UDiv: + case Instruction::URem: { + Value *LHS = EvaluateInDifferentType(I->getOperand(0), Ty, isSigned); + Value *RHS = EvaluateInDifferentType(I->getOperand(1), Ty, isSigned); + Res = BinaryOperator::Create((Instruction::BinaryOps)Opc, LHS, RHS); + break; + } + case Instruction::Trunc: + case Instruction::ZExt: + case Instruction::SExt: + // If the source type of the cast is the type we're trying for then we can + // just return the source. There's no need to insert it because it is not + // new. + if (I->getOperand(0)->getType() == Ty) + return I->getOperand(0); + + // Otherwise, must be the same type of cast, so just reinsert a new one. + // This also handles the case of zext(trunc(x)) -> zext(x). 
+ Res = CastInst::CreateIntegerCast(I->getOperand(0), Ty, + Opc == Instruction::SExt); + break; + case Instruction::Select: { + Value *True = EvaluateInDifferentType(I->getOperand(1), Ty, isSigned); + Value *False = EvaluateInDifferentType(I->getOperand(2), Ty, isSigned); + Res = SelectInst::Create(I->getOperand(0), True, False); + break; + } + case Instruction::PHI: { + PHINode *OPN = cast<PHINode>(I); + PHINode *NPN = PHINode::Create(Ty, OPN->getNumIncomingValues()); + for (unsigned i = 0, e = OPN->getNumIncomingValues(); i != e; ++i) { + Value *V = + EvaluateInDifferentType(OPN->getIncomingValue(i), Ty, isSigned); + NPN->addIncoming(V, OPN->getIncomingBlock(i)); + } + Res = NPN; + break; + } + default: + // TODO: Can handle more cases here. + llvm_unreachable("Unreachable!"); + } + + Res->takeName(I); + return InsertNewInstWith(Res, *I); +} + +Instruction::CastOps +InstCombinerImpl::isEliminableCastPair(const CastInst *CI1, + const CastInst *CI2) { + Type *SrcTy = CI1->getSrcTy(); + Type *MidTy = CI1->getDestTy(); + Type *DstTy = CI2->getDestTy(); + + Instruction::CastOps firstOp = CI1->getOpcode(); + Instruction::CastOps secondOp = CI2->getOpcode(); + Type *SrcIntPtrTy = + SrcTy->isPtrOrPtrVectorTy() ? DL.getIntPtrType(SrcTy) : nullptr; + Type *MidIntPtrTy = + MidTy->isPtrOrPtrVectorTy() ? DL.getIntPtrType(MidTy) : nullptr; + Type *DstIntPtrTy = + DstTy->isPtrOrPtrVectorTy() ? DL.getIntPtrType(DstTy) : nullptr; + unsigned Res = CastInst::isEliminableCastPair(firstOp, secondOp, SrcTy, MidTy, + DstTy, SrcIntPtrTy, MidIntPtrTy, + DstIntPtrTy); + + // We don't want to form an inttoptr or ptrtoint that converts to an integer + // type that differs from the pointer size. + if ((Res == Instruction::IntToPtr && SrcTy != DstIntPtrTy) || + (Res == Instruction::PtrToInt && DstTy != SrcIntPtrTy)) + Res = 0; + + return Instruction::CastOps(Res); +} + +/// Implement the transforms common to all CastInst visitors. 
+Instruction *InstCombinerImpl::commonCastTransforms(CastInst &CI) { + Value *Src = CI.getOperand(0); + Type *Ty = CI.getType(); + + // Try to eliminate a cast of a cast. + if (auto *CSrc = dyn_cast<CastInst>(Src)) { // A->B->C cast + if (Instruction::CastOps NewOpc = isEliminableCastPair(CSrc, &CI)) { + // The first cast (CSrc) is eliminable so we need to fix up or replace + // the second cast (CI). CSrc will then have a good chance of being dead. + auto *Res = CastInst::Create(NewOpc, CSrc->getOperand(0), Ty); + // Point debug users of the dying cast to the new one. + if (CSrc->hasOneUse()) + replaceAllDbgUsesWith(*CSrc, *Res, CI, DT); + return Res; + } + } + + if (auto *Sel = dyn_cast<SelectInst>(Src)) { + // We are casting a select. Try to fold the cast into the select if the + // select does not have a compare instruction with matching operand types + // or the select is likely better done in a narrow type. + // Creating a select with operands that are different sizes than its + // condition may inhibit other folds and lead to worse codegen. + auto *Cmp = dyn_cast<CmpInst>(Sel->getCondition()); + if (!Cmp || Cmp->getOperand(0)->getType() != Sel->getType() || + (CI.getOpcode() == Instruction::Trunc && + shouldChangeType(CI.getSrcTy(), CI.getType()))) { + if (Instruction *NV = FoldOpIntoSelect(CI, Sel)) { + replaceAllDbgUsesWith(*Sel, *NV, CI, DT); + return NV; + } + } + } + + // If we are casting a PHI, then fold the cast into the PHI. + if (auto *PN = dyn_cast<PHINode>(Src)) { + // Don't do this if it would create a PHI node with an illegal type from a + // legal type. + if (!Src->getType()->isIntegerTy() || !CI.getType()->isIntegerTy() || + shouldChangeType(CI.getSrcTy(), CI.getType())) + if (Instruction *NV = foldOpIntoPhi(CI, PN)) + return NV; + } + + // Canonicalize a unary shuffle after the cast if neither operation changes + // the size or element size of the input vector. + // TODO: We could allow size-changing ops if that doesn't harm codegen. 
+ // cast (shuffle X, Mask) --> shuffle (cast X), Mask + Value *X; + ArrayRef<int> Mask; + if (match(Src, m_OneUse(m_Shuffle(m_Value(X), m_Undef(), m_Mask(Mask))))) { + // TODO: Allow scalable vectors? + auto *SrcTy = dyn_cast<FixedVectorType>(X->getType()); + auto *DestTy = dyn_cast<FixedVectorType>(Ty); + if (SrcTy && DestTy && + SrcTy->getNumElements() == DestTy->getNumElements() && + SrcTy->getPrimitiveSizeInBits() == DestTy->getPrimitiveSizeInBits()) { + Value *CastX = Builder.CreateCast(CI.getOpcode(), X, DestTy); + return new ShuffleVectorInst(CastX, Mask); + } + } + + return nullptr; +} + +/// Constants and extensions/truncates from the destination type are always +/// free to be evaluated in that type. This is a helper for canEvaluate*. +static bool canAlwaysEvaluateInType(Value *V, Type *Ty) { + if (isa<Constant>(V)) + return true; + Value *X; + if ((match(V, m_ZExtOrSExt(m_Value(X))) || match(V, m_Trunc(m_Value(X)))) && + X->getType() == Ty) + return true; + + return false; +} + +/// Filter out values that we can not evaluate in the destination type for free. +/// This is a helper for canEvaluate*. +static bool canNotEvaluateInType(Value *V, Type *Ty) { + assert(!isa<Constant>(V) && "Constant should already be handled."); + if (!isa<Instruction>(V)) + return true; + // We don't extend or shrink something that has multiple uses -- doing so + // would require duplicating the instruction which isn't profitable. + if (!V->hasOneUse()) + return true; + + return false; +} + +/// Return true if we can evaluate the specified expression tree as type Ty +/// instead of its larger type, and arrive with the same value. +/// This is used by code that tries to eliminate truncates. +/// +/// Ty will always be a type smaller than V. We should return true if trunc(V) +/// can be computed by computing V in the smaller type. 
If V is an instruction, +/// then trunc(inst(x,y)) can be computed as inst(trunc(x),trunc(y)), which only +/// makes sense if x and y can be efficiently truncated. +/// +/// This function works on both vectors and scalars. +/// +static bool canEvaluateTruncated(Value *V, Type *Ty, InstCombinerImpl &IC, + Instruction *CxtI) { + if (canAlwaysEvaluateInType(V, Ty)) + return true; + if (canNotEvaluateInType(V, Ty)) + return false; + + auto *I = cast<Instruction>(V); + Type *OrigTy = V->getType(); + switch (I->getOpcode()) { + case Instruction::Add: + case Instruction::Sub: + case Instruction::Mul: + case Instruction::And: + case Instruction::Or: + case Instruction::Xor: + // These operators can all arbitrarily be extended or truncated. + return canEvaluateTruncated(I->getOperand(0), Ty, IC, CxtI) && + canEvaluateTruncated(I->getOperand(1), Ty, IC, CxtI); + + case Instruction::UDiv: + case Instruction::URem: { + // UDiv and URem can be truncated if all the truncated bits are zero. + uint32_t OrigBitWidth = OrigTy->getScalarSizeInBits(); + uint32_t BitWidth = Ty->getScalarSizeInBits(); + assert(BitWidth < OrigBitWidth && "Unexpected bitwidths!"); + APInt Mask = APInt::getBitsSetFrom(OrigBitWidth, BitWidth); + if (IC.MaskedValueIsZero(I->getOperand(0), Mask, 0, CxtI) && + IC.MaskedValueIsZero(I->getOperand(1), Mask, 0, CxtI)) { + return canEvaluateTruncated(I->getOperand(0), Ty, IC, CxtI) && + canEvaluateTruncated(I->getOperand(1), Ty, IC, CxtI); + } + break; + } + case Instruction::Shl: { + // If we are truncating the result of this SHL, and if it's a shift of an + // inrange amount, we can always perform a SHL in a smaller type. 
+ uint32_t BitWidth = Ty->getScalarSizeInBits(); + KnownBits AmtKnownBits = + llvm::computeKnownBits(I->getOperand(1), IC.getDataLayout()); + if (AmtKnownBits.getMaxValue().ult(BitWidth)) + return canEvaluateTruncated(I->getOperand(0), Ty, IC, CxtI) && + canEvaluateTruncated(I->getOperand(1), Ty, IC, CxtI); + break; + } + case Instruction::LShr: { + // If this is a truncate of a logical shr, we can truncate it to a smaller + // lshr iff we know that the bits we would otherwise be shifting in are + // already zeros. + // TODO: It is enough to check that the bits we would be shifting in are + // zero - use AmtKnownBits.getMaxValue(). + uint32_t OrigBitWidth = OrigTy->getScalarSizeInBits(); + uint32_t BitWidth = Ty->getScalarSizeInBits(); + KnownBits AmtKnownBits = + llvm::computeKnownBits(I->getOperand(1), IC.getDataLayout()); + APInt ShiftedBits = APInt::getBitsSetFrom(OrigBitWidth, BitWidth); + if (AmtKnownBits.getMaxValue().ult(BitWidth) && + IC.MaskedValueIsZero(I->getOperand(0), ShiftedBits, 0, CxtI)) { + return canEvaluateTruncated(I->getOperand(0), Ty, IC, CxtI) && + canEvaluateTruncated(I->getOperand(1), Ty, IC, CxtI); + } + break; + } + case Instruction::AShr: { + // If this is a truncate of an arithmetic shr, we can truncate it to a + // smaller ashr iff we know that all the bits from the sign bit of the + // original type and the sign bit of the truncate type are similar. + // TODO: It is enough to check that the bits we would be shifting in are + // similar to sign bit of the truncate type. 
+ uint32_t OrigBitWidth = OrigTy->getScalarSizeInBits(); + uint32_t BitWidth = Ty->getScalarSizeInBits(); + KnownBits AmtKnownBits = + llvm::computeKnownBits(I->getOperand(1), IC.getDataLayout()); + unsigned ShiftedBits = OrigBitWidth - BitWidth; + if (AmtKnownBits.getMaxValue().ult(BitWidth) && + ShiftedBits < IC.ComputeNumSignBits(I->getOperand(0), 0, CxtI)) + return canEvaluateTruncated(I->getOperand(0), Ty, IC, CxtI) && + canEvaluateTruncated(I->getOperand(1), Ty, IC, CxtI); + break; + } + case Instruction::Trunc: + // trunc(trunc(x)) -> trunc(x) + return true; + case Instruction::ZExt: + case Instruction::SExt: + // trunc(ext(x)) -> ext(x) if the source type is smaller than the new dest + // trunc(ext(x)) -> trunc(x) if the source type is larger than the new dest + return true; + case Instruction::Select: { + SelectInst *SI = cast<SelectInst>(I); + return canEvaluateTruncated(SI->getTrueValue(), Ty, IC, CxtI) && + canEvaluateTruncated(SI->getFalseValue(), Ty, IC, CxtI); + } + case Instruction::PHI: { + // We can change a phi if we can change all operands. Note that we never + // get into trouble with cyclic PHIs here because we only consider + // instructions with a single use. + PHINode *PN = cast<PHINode>(I); + for (Value *IncValue : PN->incoming_values()) + if (!canEvaluateTruncated(IncValue, Ty, IC, CxtI)) + return false; + return true; + } + default: + // TODO: Can handle more cases here. + break; + } + + return false; +} + +/// Given a vector that is bitcast to an integer, optionally logically +/// right-shifted, and truncated, convert it to an extractelement. 
+/// Example (big endian): +/// trunc (lshr (bitcast <4 x i32> %X to i128), 32) to i32 +/// ---> +/// extractelement <4 x i32> %X, 1 +static Instruction *foldVecTruncToExtElt(TruncInst &Trunc, + InstCombinerImpl &IC) { + Value *TruncOp = Trunc.getOperand(0); + Type *DestType = Trunc.getType(); + if (!TruncOp->hasOneUse() || !isa<IntegerType>(DestType)) + return nullptr; + + Value *VecInput = nullptr; + ConstantInt *ShiftVal = nullptr; + if (!match(TruncOp, m_CombineOr(m_BitCast(m_Value(VecInput)), + m_LShr(m_BitCast(m_Value(VecInput)), + m_ConstantInt(ShiftVal)))) || + !isa<VectorType>(VecInput->getType())) + return nullptr; + + VectorType *VecType = cast<VectorType>(VecInput->getType()); + unsigned VecWidth = VecType->getPrimitiveSizeInBits(); + unsigned DestWidth = DestType->getPrimitiveSizeInBits(); + unsigned ShiftAmount = ShiftVal ? ShiftVal->getZExtValue() : 0; + + if ((VecWidth % DestWidth != 0) || (ShiftAmount % DestWidth != 0)) + return nullptr; + + // If the element type of the vector doesn't match the result type, + // bitcast it to a vector type that we can extract from. + unsigned NumVecElts = VecWidth / DestWidth; + if (VecType->getElementType() != DestType) { + VecType = FixedVectorType::get(DestType, NumVecElts); + VecInput = IC.Builder.CreateBitCast(VecInput, VecType, "bc"); + } + + unsigned Elt = ShiftAmount / DestWidth; + if (IC.getDataLayout().isBigEndian()) + Elt = NumVecElts - 1 - Elt; + + return ExtractElementInst::Create(VecInput, IC.Builder.getInt32(Elt)); +} + +/// Funnel/Rotate left/right may occur in a wider type than necessary because of +/// type promotion rules. Try to narrow the inputs and convert to funnel shift. +Instruction *InstCombinerImpl::narrowFunnelShift(TruncInst &Trunc) { + assert((isa<VectorType>(Trunc.getSrcTy()) || + shouldChangeType(Trunc.getSrcTy(), Trunc.getType())) && + "Don't narrow to an illegal scalar type"); + + // Bail out on strange types. 
It is possible to handle some of these patterns + // even with non-power-of-2 sizes, but it is not a likely scenario. + Type *DestTy = Trunc.getType(); + unsigned NarrowWidth = DestTy->getScalarSizeInBits(); + unsigned WideWidth = Trunc.getSrcTy()->getScalarSizeInBits(); + if (!isPowerOf2_32(NarrowWidth)) + return nullptr; + + // First, find an or'd pair of opposite shifts: + // trunc (or (lshr ShVal0, ShAmt0), (shl ShVal1, ShAmt1)) + BinaryOperator *Or0, *Or1; + if (!match(Trunc.getOperand(0), m_OneUse(m_Or(m_BinOp(Or0), m_BinOp(Or1))))) + return nullptr; + + Value *ShVal0, *ShVal1, *ShAmt0, *ShAmt1; + if (!match(Or0, m_OneUse(m_LogicalShift(m_Value(ShVal0), m_Value(ShAmt0)))) || + !match(Or1, m_OneUse(m_LogicalShift(m_Value(ShVal1), m_Value(ShAmt1)))) || + Or0->getOpcode() == Or1->getOpcode()) + return nullptr; + + // Canonicalize to or(shl(ShVal0, ShAmt0), lshr(ShVal1, ShAmt1)). + if (Or0->getOpcode() == BinaryOperator::LShr) { + std::swap(Or0, Or1); + std::swap(ShVal0, ShVal1); + std::swap(ShAmt0, ShAmt1); + } + assert(Or0->getOpcode() == BinaryOperator::Shl && + Or1->getOpcode() == BinaryOperator::LShr && + "Illegal or(shift,shift) pair"); + + // Match the shift amount operands for a funnel/rotate pattern. This always + // matches a subtraction on the R operand. + auto matchShiftAmount = [&](Value *L, Value *R, unsigned Width) -> Value * { + // The shift amounts may add up to the narrow bit width: + // (shl ShVal0, L) | (lshr ShVal1, Width - L) + // If this is a funnel shift (different operands are shifted), then the + // shift amount can not over-shift (create poison) in the narrow type. + unsigned MaxShiftAmountWidth = Log2_32(NarrowWidth); + APInt HiBitMask = ~APInt::getLowBitsSet(WideWidth, MaxShiftAmountWidth); + if (ShVal0 == ShVal1 || MaskedValueIsZero(L, HiBitMask)) + if (match(R, m_OneUse(m_Sub(m_SpecificInt(Width), m_Specific(L))))) + return L; + + // The following patterns currently only work for rotation patterns. 
+ // TODO: Add more general funnel-shift compatible patterns. + if (ShVal0 != ShVal1) + return nullptr; + + // The shift amount may be masked with negation: + // (shl ShVal0, (X & (Width - 1))) | (lshr ShVal1, ((-X) & (Width - 1))) + Value *X; + unsigned Mask = Width - 1; + if (match(L, m_And(m_Value(X), m_SpecificInt(Mask))) && + match(R, m_And(m_Neg(m_Specific(X)), m_SpecificInt(Mask)))) + return X; + + // Same as above, but the shift amount may be extended after masking: + if (match(L, m_ZExt(m_And(m_Value(X), m_SpecificInt(Mask)))) && + match(R, m_ZExt(m_And(m_Neg(m_Specific(X)), m_SpecificInt(Mask))))) + return X; + + return nullptr; + }; + + Value *ShAmt = matchShiftAmount(ShAmt0, ShAmt1, NarrowWidth); + bool IsFshl = true; // Sub on LSHR. + if (!ShAmt) { + ShAmt = matchShiftAmount(ShAmt1, ShAmt0, NarrowWidth); + IsFshl = false; // Sub on SHL. + } + if (!ShAmt) + return nullptr; + + // The right-shifted value must have high zeros in the wide type (for example + // from 'zext', 'and' or 'shift'). High bits of the left-shifted value are + // truncated, so those do not matter. + APInt HiBitMask = APInt::getHighBitsSet(WideWidth, WideWidth - NarrowWidth); + if (!MaskedValueIsZero(ShVal1, HiBitMask, 0, &Trunc)) + return nullptr; + + // We have an unnecessarily wide rotate! + // trunc (or (shl ShVal0, ShAmt), (lshr ShVal1, BitWidth - ShAmt)) + // Narrow the inputs and convert to funnel shift intrinsic: + // llvm.fshl.i8(trunc(ShVal), trunc(ShVal), trunc(ShAmt)) + Value *NarrowShAmt = Builder.CreateTrunc(ShAmt, DestTy); + Value *X, *Y; + X = Y = Builder.CreateTrunc(ShVal0, DestTy); + if (ShVal0 != ShVal1) + Y = Builder.CreateTrunc(ShVal1, DestTy); + Intrinsic::ID IID = IsFshl ? Intrinsic::fshl : Intrinsic::fshr; + Function *F = Intrinsic::getDeclaration(Trunc.getModule(), IID, DestTy); + return CallInst::Create(F, {X, Y, NarrowShAmt}); +} + +/// Try to narrow the width of math or bitwise logic instructions by pulling a +/// truncate ahead of binary operators. 
+Instruction *InstCombinerImpl::narrowBinOp(TruncInst &Trunc) { + Type *SrcTy = Trunc.getSrcTy(); + Type *DestTy = Trunc.getType(); + unsigned SrcWidth = SrcTy->getScalarSizeInBits(); + unsigned DestWidth = DestTy->getScalarSizeInBits(); + + if (!isa<VectorType>(SrcTy) && !shouldChangeType(SrcTy, DestTy)) + return nullptr; + + BinaryOperator *BinOp; + if (!match(Trunc.getOperand(0), m_OneUse(m_BinOp(BinOp)))) + return nullptr; + + Value *BinOp0 = BinOp->getOperand(0); + Value *BinOp1 = BinOp->getOperand(1); + switch (BinOp->getOpcode()) { + case Instruction::And: + case Instruction::Or: + case Instruction::Xor: + case Instruction::Add: + case Instruction::Sub: + case Instruction::Mul: { + Constant *C; + if (match(BinOp0, m_Constant(C))) { + // trunc (binop C, X) --> binop (trunc C', X) + Constant *NarrowC = ConstantExpr::getTrunc(C, DestTy); + Value *TruncX = Builder.CreateTrunc(BinOp1, DestTy); + return BinaryOperator::Create(BinOp->getOpcode(), NarrowC, TruncX); + } + if (match(BinOp1, m_Constant(C))) { + // trunc (binop X, C) --> binop (trunc X, C') + Constant *NarrowC = ConstantExpr::getTrunc(C, DestTy); + Value *TruncX = Builder.CreateTrunc(BinOp0, DestTy); + return BinaryOperator::Create(BinOp->getOpcode(), TruncX, NarrowC); + } + Value *X; + if (match(BinOp0, m_ZExtOrSExt(m_Value(X))) && X->getType() == DestTy) { + // trunc (binop (ext X), Y) --> binop X, (trunc Y) + Value *NarrowOp1 = Builder.CreateTrunc(BinOp1, DestTy); + return BinaryOperator::Create(BinOp->getOpcode(), X, NarrowOp1); + } + if (match(BinOp1, m_ZExtOrSExt(m_Value(X))) && X->getType() == DestTy) { + // trunc (binop Y, (ext X)) --> binop (trunc Y), X + Value *NarrowOp0 = Builder.CreateTrunc(BinOp0, DestTy); + return BinaryOperator::Create(BinOp->getOpcode(), NarrowOp0, X); + } + break; + } + case Instruction::LShr: + case Instruction::AShr: { + // trunc (*shr (trunc A), C) --> trunc(*shr A, C) + Value *A; + Constant *C; + if (match(BinOp0, m_Trunc(m_Value(A))) && match(BinOp1, 
m_Constant(C))) { + unsigned MaxShiftAmt = SrcWidth - DestWidth; + // If the shift is small enough, all zero/sign bits created by the shift + // are removed by the trunc. + if (match(C, m_SpecificInt_ICMP(ICmpInst::ICMP_ULE, + APInt(SrcWidth, MaxShiftAmt)))) { + auto *OldShift = cast<Instruction>(Trunc.getOperand(0)); + bool IsExact = OldShift->isExact(); + auto *ShAmt = ConstantExpr::getIntegerCast(C, A->getType(), true); + ShAmt = Constant::mergeUndefsWith(ShAmt, C); + Value *Shift = + OldShift->getOpcode() == Instruction::AShr + ? Builder.CreateAShr(A, ShAmt, OldShift->getName(), IsExact) + : Builder.CreateLShr(A, ShAmt, OldShift->getName(), IsExact); + return CastInst::CreateTruncOrBitCast(Shift, DestTy); + } + } + break; + } + default: break; + } + + if (Instruction *NarrowOr = narrowFunnelShift(Trunc)) + return NarrowOr; + + return nullptr; +} + +/// Try to narrow the width of a splat shuffle. This could be generalized to any +/// shuffle with a constant operand, but we limit the transform to avoid +/// creating a shuffle type that targets may not be able to lower effectively. +static Instruction *shrinkSplatShuffle(TruncInst &Trunc, + InstCombiner::BuilderTy &Builder) { + auto *Shuf = dyn_cast<ShuffleVectorInst>(Trunc.getOperand(0)); + if (Shuf && Shuf->hasOneUse() && match(Shuf->getOperand(1), m_Undef()) && + is_splat(Shuf->getShuffleMask()) && + Shuf->getType() == Shuf->getOperand(0)->getType()) { + // trunc (shuf X, Undef, SplatMask) --> shuf (trunc X), Poison, SplatMask + // trunc (shuf X, Poison, SplatMask) --> shuf (trunc X), Poison, SplatMask + Value *NarrowOp = Builder.CreateTrunc(Shuf->getOperand(0), Trunc.getType()); + return new ShuffleVectorInst(NarrowOp, Shuf->getShuffleMask()); + } + + return nullptr; +} + +/// Try to narrow the width of an insert element. This could be generalized for +/// any vector constant, but we limit the transform to insertion into undef to +/// avoid potential backend problems from unsupported insertion widths. 
This +/// could also be extended to handle the case of inserting a scalar constant +/// into a vector variable. +static Instruction *shrinkInsertElt(CastInst &Trunc, + InstCombiner::BuilderTy &Builder) { + Instruction::CastOps Opcode = Trunc.getOpcode(); + assert((Opcode == Instruction::Trunc || Opcode == Instruction::FPTrunc) && + "Unexpected instruction for shrinking"); + + auto *InsElt = dyn_cast<InsertElementInst>(Trunc.getOperand(0)); + if (!InsElt || !InsElt->hasOneUse()) + return nullptr; + + Type *DestTy = Trunc.getType(); + Type *DestScalarTy = DestTy->getScalarType(); + Value *VecOp = InsElt->getOperand(0); + Value *ScalarOp = InsElt->getOperand(1); + Value *Index = InsElt->getOperand(2); + + if (match(VecOp, m_Undef())) { + // trunc (inselt undef, X, Index) --> inselt undef, (trunc X), Index + // fptrunc (inselt undef, X, Index) --> inselt undef, (fptrunc X), Index + UndefValue *NarrowUndef = UndefValue::get(DestTy); + Value *NarrowOp = Builder.CreateCast(Opcode, ScalarOp, DestScalarTy); + return InsertElementInst::Create(NarrowUndef, NarrowOp, Index); + } + + return nullptr; +} + +Instruction *InstCombinerImpl::visitTrunc(TruncInst &Trunc) { + if (Instruction *Result = commonCastTransforms(Trunc)) + return Result; + + Value *Src = Trunc.getOperand(0); + Type *DestTy = Trunc.getType(), *SrcTy = Src->getType(); + unsigned DestWidth = DestTy->getScalarSizeInBits(); + unsigned SrcWidth = SrcTy->getScalarSizeInBits(); + + // Attempt to truncate the entire input expression tree to the destination + // type. Only do this if the dest type is a simple type, don't convert the + // expression tree to something weird like i93 unless the source is also + // strange. + if ((DestTy->isVectorTy() || shouldChangeType(SrcTy, DestTy)) && + canEvaluateTruncated(Src, DestTy, *this, &Trunc)) { + + // If this cast is a truncate, evaluting in a different type always + // eliminates the cast, so it is always a win. 
+ LLVM_DEBUG( + dbgs() << "ICE: EvaluateInDifferentType converting expression type" + " to avoid cast: " + << Trunc << '\n'); + Value *Res = EvaluateInDifferentType(Src, DestTy, false); + assert(Res->getType() == DestTy); + return replaceInstUsesWith(Trunc, Res); + } + + // For integer types, check if we can shorten the entire input expression to + // DestWidth * 2, which won't allow removing the truncate, but reducing the + // width may enable further optimizations, e.g. allowing for larger + // vectorization factors. + if (auto *DestITy = dyn_cast<IntegerType>(DestTy)) { + if (DestWidth * 2 < SrcWidth) { + auto *NewDestTy = DestITy->getExtendedType(); + if (shouldChangeType(SrcTy, NewDestTy) && + canEvaluateTruncated(Src, NewDestTy, *this, &Trunc)) { + LLVM_DEBUG( + dbgs() << "ICE: EvaluateInDifferentType converting expression type" + " to reduce the width of operand of" + << Trunc << '\n'); + Value *Res = EvaluateInDifferentType(Src, NewDestTy, false); + return new TruncInst(Res, DestTy); + } + } + } + + // Test if the trunc is the user of a select which is part of a + // minimum or maximum operation. If so, don't do any more simplification. + // Even simplifying demanded bits can break the canonical form of a + // min/max. + Value *LHS, *RHS; + if (SelectInst *Sel = dyn_cast<SelectInst>(Src)) + if (matchSelectPattern(Sel, LHS, RHS).Flavor != SPF_UNKNOWN) + return nullptr; + + // See if we can simplify any instructions used by the input whose sole + // purpose is to compute bits we don't care about. + if (SimplifyDemandedInstructionBits(Trunc)) + return &Trunc; + + if (DestWidth == 1) { + Value *Zero = Constant::getNullValue(SrcTy); + if (DestTy->isIntegerTy()) { + // Canonicalize trunc x to i1 -> icmp ne (and x, 1), 0 (scalar only). + // TODO: We canonicalize to more instructions here because we are probably + // lacking equivalent analysis for trunc relative to icmp. There may also + // be codegen concerns. 
If those trunc limitations were removed, we could + // remove this transform. + Value *And = Builder.CreateAnd(Src, ConstantInt::get(SrcTy, 1)); + return new ICmpInst(ICmpInst::ICMP_NE, And, Zero); + } + + // For vectors, we do not canonicalize all truncs to icmp, so optimize + // patterns that would be covered within visitICmpInst. + Value *X; + Constant *C; + if (match(Src, m_OneUse(m_LShr(m_Value(X), m_Constant(C))))) { + // trunc (lshr X, C) to i1 --> icmp ne (and X, C'), 0 + Constant *One = ConstantInt::get(SrcTy, APInt(SrcWidth, 1)); + Constant *MaskC = ConstantExpr::getShl(One, C); + Value *And = Builder.CreateAnd(X, MaskC); + return new ICmpInst(ICmpInst::ICMP_NE, And, Zero); + } + if (match(Src, m_OneUse(m_c_Or(m_LShr(m_Value(X), m_Constant(C)), + m_Deferred(X))))) { + // trunc (or (lshr X, C), X) to i1 --> icmp ne (and X, C'), 0 + Constant *One = ConstantInt::get(SrcTy, APInt(SrcWidth, 1)); + Constant *MaskC = ConstantExpr::getShl(One, C); + MaskC = ConstantExpr::getOr(MaskC, One); + Value *And = Builder.CreateAnd(X, MaskC); + return new ICmpInst(ICmpInst::ICMP_NE, And, Zero); + } + } + + Value *A, *B; + Constant *C; + if (match(Src, m_LShr(m_SExt(m_Value(A)), m_Constant(C)))) { + unsigned AWidth = A->getType()->getScalarSizeInBits(); + unsigned MaxShiftAmt = SrcWidth - std::max(DestWidth, AWidth); + auto *OldSh = cast<Instruction>(Src); + bool IsExact = OldSh->isExact(); + + // If the shift is small enough, all zero bits created by the shift are + // removed by the trunc. + if (match(C, m_SpecificInt_ICMP(ICmpInst::ICMP_ULE, + APInt(SrcWidth, MaxShiftAmt)))) { + // trunc (lshr (sext A), C) --> ashr A, C + if (A->getType() == DestTy) { + Constant *MaxAmt = ConstantInt::get(SrcTy, DestWidth - 1, false); + Constant *ShAmt = ConstantExpr::getUMin(C, MaxAmt); + ShAmt = ConstantExpr::getTrunc(ShAmt, A->getType()); + ShAmt = Constant::mergeUndefsWith(ShAmt, C); + return IsExact ? 
BinaryOperator::CreateExactAShr(A, ShAmt) + : BinaryOperator::CreateAShr(A, ShAmt); + } + // The types are mismatched, so create a cast after shifting: + // trunc (lshr (sext A), C) --> sext/trunc (ashr A, C) + if (Src->hasOneUse()) { + Constant *MaxAmt = ConstantInt::get(SrcTy, AWidth - 1, false); + Constant *ShAmt = ConstantExpr::getUMin(C, MaxAmt); + ShAmt = ConstantExpr::getTrunc(ShAmt, A->getType()); + Value *Shift = Builder.CreateAShr(A, ShAmt, "", IsExact); + return CastInst::CreateIntegerCast(Shift, DestTy, true); + } + } + // TODO: Mask high bits with 'and'. + } + + if (Instruction *I = narrowBinOp(Trunc)) + return I; + + if (Instruction *I = shrinkSplatShuffle(Trunc, Builder)) + return I; + + if (Instruction *I = shrinkInsertElt(Trunc, Builder)) + return I; + + if (Src->hasOneUse() && + (isa<VectorType>(SrcTy) || shouldChangeType(SrcTy, DestTy))) { + // Transform "trunc (shl X, cst)" -> "shl (trunc X), cst" so long as the + // dest type is native and cst < dest size. + if (match(Src, m_Shl(m_Value(A), m_Constant(C))) && + !match(A, m_Shr(m_Value(), m_Constant()))) { + // Skip shifts of shift by constants. It undoes a combine in + // FoldShiftByConstant and is the extend in reg pattern. + APInt Threshold = APInt(C->getType()->getScalarSizeInBits(), DestWidth); + if (match(C, m_SpecificInt_ICMP(ICmpInst::ICMP_ULT, Threshold))) { + Value *NewTrunc = Builder.CreateTrunc(A, DestTy, A->getName() + ".tr"); + return BinaryOperator::Create(Instruction::Shl, NewTrunc, + ConstantExpr::getTrunc(C, DestTy)); + } + } + } + + if (Instruction *I = foldVecTruncToExtElt(Trunc, *this)) + return I; + + // Whenever an element is extracted from a vector, and then truncated, + // canonicalize by converting it to a bitcast followed by an + // extractelement. 
+ // + // Example (little endian): + // trunc (extractelement <4 x i64> %X, 0) to i32 + // ---> + // extractelement <8 x i32> (bitcast <4 x i64> %X to <8 x i32>), i32 0 + Value *VecOp; + ConstantInt *Cst; + if (match(Src, m_OneUse(m_ExtractElt(m_Value(VecOp), m_ConstantInt(Cst))))) { + auto *VecOpTy = cast<VectorType>(VecOp->getType()); + auto VecElts = VecOpTy->getElementCount(); + + // A badly fit destination size would result in an invalid cast. + if (SrcWidth % DestWidth == 0) { + uint64_t TruncRatio = SrcWidth / DestWidth; + uint64_t BitCastNumElts = VecElts.getKnownMinValue() * TruncRatio; + uint64_t VecOpIdx = Cst->getZExtValue(); + uint64_t NewIdx = DL.isBigEndian() ? (VecOpIdx + 1) * TruncRatio - 1 + : VecOpIdx * TruncRatio; + assert(BitCastNumElts <= std::numeric_limits<uint32_t>::max() && + "overflow 32-bits"); + + auto *BitCastTo = + VectorType::get(DestTy, BitCastNumElts, VecElts.isScalable()); + Value *BitCast = Builder.CreateBitCast(VecOp, BitCastTo); + return ExtractElementInst::Create(BitCast, Builder.getInt32(NewIdx)); + } + } + + // trunc (ctlz_i32(zext(A), B) --> add(ctlz_i16(A, B), C) + if (match(Src, m_OneUse(m_Intrinsic<Intrinsic::ctlz>(m_ZExt(m_Value(A)), + m_Value(B))))) { + unsigned AWidth = A->getType()->getScalarSizeInBits(); + if (AWidth == DestWidth && AWidth > Log2_32(SrcWidth)) { + Value *WidthDiff = ConstantInt::get(A->getType(), SrcWidth - AWidth); + Value *NarrowCtlz = + Builder.CreateIntrinsic(Intrinsic::ctlz, {Trunc.getType()}, {A, B}); + return BinaryOperator::CreateAdd(NarrowCtlz, WidthDiff); + } + } + + if (match(Src, m_VScale(DL))) { + if (Trunc.getFunction() && + Trunc.getFunction()->hasFnAttribute(Attribute::VScaleRange)) { + Attribute Attr = + Trunc.getFunction()->getFnAttribute(Attribute::VScaleRange); + if (Optional<unsigned> MaxVScale = Attr.getVScaleRangeMax()) { + if (Log2_32(*MaxVScale) < DestWidth) { + Value *VScale = Builder.CreateVScale(ConstantInt::get(DestTy, 1)); + return replaceInstUsesWith(Trunc, VScale); + 
} + } + } + } + + return nullptr; +} + +Instruction *InstCombinerImpl::transformZExtICmp(ICmpInst *Cmp, ZExtInst &Zext) { + // If we are just checking for a icmp eq of a single bit and zext'ing it + // to an integer, then shift the bit to the appropriate place and then + // cast to integer to avoid the comparison. + + // FIXME: This set of transforms does not check for extra uses and/or creates + // an extra instruction (an optional final cast is not included + // in the transform comments). We may also want to favor icmp over + // shifts in cases of equal instructions because icmp has better + // analysis in general (invert the transform). + + const APInt *Op1CV; + if (match(Cmp->getOperand(1), m_APInt(Op1CV))) { + + // zext (x <s 0) to i32 --> x>>u31 true if signbit set. + if (Cmp->getPredicate() == ICmpInst::ICMP_SLT && Op1CV->isZero()) { + Value *In = Cmp->getOperand(0); + Value *Sh = ConstantInt::get(In->getType(), + In->getType()->getScalarSizeInBits() - 1); + In = Builder.CreateLShr(In, Sh, In->getName() + ".lobit"); + if (In->getType() != Zext.getType()) + In = Builder.CreateIntCast(In, Zext.getType(), false /*ZExt*/); + + return replaceInstUsesWith(Zext, In); + } + + // zext (X == 0) to i32 --> X^1 iff X has only the low bit set. + // zext (X == 0) to i32 --> (X>>1)^1 iff X has only the 2nd bit set. + // zext (X == 1) to i32 --> X iff X has only the low bit set. + // zext (X == 2) to i32 --> X>>1 iff X has only the 2nd bit set. + // zext (X != 0) to i32 --> X iff X has only the low bit set. + // zext (X != 0) to i32 --> X>>1 iff X has only the 2nd bit set. + // zext (X != 1) to i32 --> X^1 iff X has only the low bit set. + // zext (X != 2) to i32 --> (X>>1)^1 iff X has only the 2nd bit set. 
+ if ((Op1CV->isZero() || Op1CV->isPowerOf2()) && + // This only works for EQ and NE + Cmp->isEquality()) { + // If Op1C some other power of two, convert: + KnownBits Known = computeKnownBits(Cmp->getOperand(0), 0, &Zext); + + APInt KnownZeroMask(~Known.Zero); + if (KnownZeroMask.isPowerOf2()) { // Exactly 1 possible 1? + bool isNE = Cmp->getPredicate() == ICmpInst::ICMP_NE; + if (!Op1CV->isZero() && (*Op1CV != KnownZeroMask)) { + // (X&4) == 2 --> false + // (X&4) != 2 --> true + Constant *Res = ConstantInt::get(Zext.getType(), isNE); + return replaceInstUsesWith(Zext, Res); + } + + uint32_t ShAmt = KnownZeroMask.logBase2(); + Value *In = Cmp->getOperand(0); + if (ShAmt) { + // Perform a logical shr by shiftamt. + // Insert the shift to put the result in the low bit. + In = Builder.CreateLShr(In, ConstantInt::get(In->getType(), ShAmt), + In->getName() + ".lobit"); + } + + if (!Op1CV->isZero() == isNE) { // Toggle the low bit. + Constant *One = ConstantInt::get(In->getType(), 1); + In = Builder.CreateXor(In, One); + } + + if (Zext.getType() == In->getType()) + return replaceInstUsesWith(Zext, In); + + Value *IntCast = Builder.CreateIntCast(In, Zext.getType(), false); + return replaceInstUsesWith(Zext, IntCast); + } + } + } + + if (Cmp->isEquality() && Zext.getType() == Cmp->getOperand(0)->getType()) { + // Test if a bit is clear/set using a shifted-one mask: + // zext (icmp eq (and X, (1 << ShAmt)), 0) --> and (lshr (not X), ShAmt), 1 + // zext (icmp ne (and X, (1 << ShAmt)), 0) --> and (lshr X, ShAmt), 1 + Value *X, *ShAmt; + if (Cmp->hasOneUse() && match(Cmp->getOperand(1), m_ZeroInt()) && + match(Cmp->getOperand(0), + m_OneUse(m_c_And(m_Shl(m_One(), m_Value(ShAmt)), m_Value(X))))) { + if (Cmp->getPredicate() == ICmpInst::ICMP_EQ) + X = Builder.CreateNot(X); + Value *Lshr = Builder.CreateLShr(X, ShAmt); + Value *And1 = Builder.CreateAnd(Lshr, ConstantInt::get(X->getType(), 1)); + return replaceInstUsesWith(Zext, And1); + } + + // icmp ne A, B is equal to xor A, B 
when A and B only really have one bit. + // It is also profitable to transform icmp eq into not(xor(A, B)) because + // that may lead to additional simplifications. + if (IntegerType *ITy = dyn_cast<IntegerType>(Zext.getType())) { + Value *LHS = Cmp->getOperand(0); + Value *RHS = Cmp->getOperand(1); + + KnownBits KnownLHS = computeKnownBits(LHS, 0, &Zext); + KnownBits KnownRHS = computeKnownBits(RHS, 0, &Zext); + + if (KnownLHS == KnownRHS) { + APInt KnownBits = KnownLHS.Zero | KnownLHS.One; + APInt UnknownBit = ~KnownBits; + if (UnknownBit.countPopulation() == 1) { + Value *Result = Builder.CreateXor(LHS, RHS); + + // Mask off any bits that are set and won't be shifted away. + if (KnownLHS.One.uge(UnknownBit)) + Result = Builder.CreateAnd(Result, + ConstantInt::get(ITy, UnknownBit)); + + // Shift the bit we're testing down to the lsb. + Result = Builder.CreateLShr( + Result, ConstantInt::get(ITy, UnknownBit.countTrailingZeros())); + + if (Cmp->getPredicate() == ICmpInst::ICMP_EQ) + Result = Builder.CreateXor(Result, ConstantInt::get(ITy, 1)); + Result->takeName(Cmp); + return replaceInstUsesWith(Zext, Result); + } + } + } + } + + return nullptr; +} + +/// Determine if the specified value can be computed in the specified wider type +/// and produce the same low bits. If not, return false. +/// +/// If this function returns true, it can also return a non-zero number of bits +/// (in BitsToClear) which indicates that the value it computes is correct for +/// the zero extend, but that the additional BitsToClear bits need to be zero'd +/// out. For example, to promote something like: +/// +/// %B = trunc i64 %A to i32 +/// %C = lshr i32 %B, 8 +/// %E = zext i32 %C to i64 +/// +/// CanEvaluateZExtd for the 'lshr' will return true, and BitsToClear will be +/// set to 8 to indicate that the promoted value needs to have bits 24-31 +/// cleared in addition to bits 32-63. Since an 'and' will be generated to +/// clear the top bits anyway, doing this has no extra cost. 
+/// +/// This function works on both vectors and scalars. +static bool canEvaluateZExtd(Value *V, Type *Ty, unsigned &BitsToClear, + InstCombinerImpl &IC, Instruction *CxtI) { + BitsToClear = 0; + if (canAlwaysEvaluateInType(V, Ty)) + return true; + if (canNotEvaluateInType(V, Ty)) + return false; + + auto *I = cast<Instruction>(V); + unsigned Tmp; + switch (I->getOpcode()) { + case Instruction::ZExt: // zext(zext(x)) -> zext(x). + case Instruction::SExt: // zext(sext(x)) -> sext(x). + case Instruction::Trunc: // zext(trunc(x)) -> trunc(x) or zext(x) + return true; + case Instruction::And: + case Instruction::Or: + case Instruction::Xor: + case Instruction::Add: + case Instruction::Sub: + case Instruction::Mul: + if (!canEvaluateZExtd(I->getOperand(0), Ty, BitsToClear, IC, CxtI) || + !canEvaluateZExtd(I->getOperand(1), Ty, Tmp, IC, CxtI)) + return false; + // These can all be promoted if neither operand has 'bits to clear'. + if (BitsToClear == 0 && Tmp == 0) + return true; + + // If the operation is an AND/OR/XOR and the bits to clear are zero in the + // other side, BitsToClear is ok. + if (Tmp == 0 && I->isBitwiseLogicOp()) { + // We use MaskedValueIsZero here for generality, but the case we care + // about the most is constant RHS. + unsigned VSize = V->getType()->getScalarSizeInBits(); + if (IC.MaskedValueIsZero(I->getOperand(1), + APInt::getHighBitsSet(VSize, BitsToClear), + 0, CxtI)) { + // If this is an And instruction and all of the BitsToClear are + // known to be zero we can reset BitsToClear. + if (I->getOpcode() == Instruction::And) + BitsToClear = 0; + return true; + } + } + + // Otherwise, we don't know how to analyze this BitsToClear case yet. + return false; + + case Instruction::Shl: { + // We can promote shl(x, cst) if we can promote x. Since shl overwrites the + // upper bits we can reduce BitsToClear by the shift amount. 
+ const APInt *Amt; + if (match(I->getOperand(1), m_APInt(Amt))) { + if (!canEvaluateZExtd(I->getOperand(0), Ty, BitsToClear, IC, CxtI)) + return false; + uint64_t ShiftAmt = Amt->getZExtValue(); + BitsToClear = ShiftAmt < BitsToClear ? BitsToClear - ShiftAmt : 0; + return true; + } + return false; + } + case Instruction::LShr: { + // We can promote lshr(x, cst) if we can promote x. This requires the + // ultimate 'and' to clear out the high zero bits we're clearing out though. + const APInt *Amt; + if (match(I->getOperand(1), m_APInt(Amt))) { + if (!canEvaluateZExtd(I->getOperand(0), Ty, BitsToClear, IC, CxtI)) + return false; + BitsToClear += Amt->getZExtValue(); + if (BitsToClear > V->getType()->getScalarSizeInBits()) + BitsToClear = V->getType()->getScalarSizeInBits(); + return true; + } + // Cannot promote variable LSHR. + return false; + } + case Instruction::Select: + if (!canEvaluateZExtd(I->getOperand(1), Ty, Tmp, IC, CxtI) || + !canEvaluateZExtd(I->getOperand(2), Ty, BitsToClear, IC, CxtI) || + // TODO: If important, we could handle the case when the BitsToClear are + // known zero in the disagreeing side. + Tmp != BitsToClear) + return false; + return true; + + case Instruction::PHI: { + // We can change a phi if we can change all operands. Note that we never + // get into trouble with cyclic PHIs here because we only consider + // instructions with a single use. + PHINode *PN = cast<PHINode>(I); + if (!canEvaluateZExtd(PN->getIncomingValue(0), Ty, BitsToClear, IC, CxtI)) + return false; + for (unsigned i = 1, e = PN->getNumIncomingValues(); i != e; ++i) + if (!canEvaluateZExtd(PN->getIncomingValue(i), Ty, Tmp, IC, CxtI) || + // TODO: If important, we could handle the case when the BitsToClear + // are known zero in the disagreeing input. + Tmp != BitsToClear) + return false; + return true; + } + default: + // TODO: Can handle more cases here. 
+ return false; + } +} + +Instruction *InstCombinerImpl::visitZExt(ZExtInst &CI) { + // If this zero extend is only used by a truncate, let the truncate be + // eliminated before we try to optimize this zext. + if (CI.hasOneUse() && isa<TruncInst>(CI.user_back())) + return nullptr; + + // If one of the common conversion will work, do it. + if (Instruction *Result = commonCastTransforms(CI)) + return Result; + + Value *Src = CI.getOperand(0); + Type *SrcTy = Src->getType(), *DestTy = CI.getType(); + + // Try to extend the entire expression tree to the wide destination type. + unsigned BitsToClear; + if (shouldChangeType(SrcTy, DestTy) && + canEvaluateZExtd(Src, DestTy, BitsToClear, *this, &CI)) { + assert(BitsToClear <= SrcTy->getScalarSizeInBits() && + "Can't clear more bits than in SrcTy"); + + // Okay, we can transform this! Insert the new expression now. + LLVM_DEBUG( + dbgs() << "ICE: EvaluateInDifferentType converting expression type" + " to avoid zero extend: " + << CI << '\n'); + Value *Res = EvaluateInDifferentType(Src, DestTy, false); + assert(Res->getType() == DestTy); + + // Preserve debug values referring to Src if the zext is its last use. + if (auto *SrcOp = dyn_cast<Instruction>(Src)) + if (SrcOp->hasOneUse()) + replaceAllDbgUsesWith(*SrcOp, *Res, CI, DT); + + uint32_t SrcBitsKept = SrcTy->getScalarSizeInBits()-BitsToClear; + uint32_t DestBitSize = DestTy->getScalarSizeInBits(); + + // If the high bits are already filled with zeros, just replace this + // cast with the result. + if (MaskedValueIsZero(Res, + APInt::getHighBitsSet(DestBitSize, + DestBitSize-SrcBitsKept), + 0, &CI)) + return replaceInstUsesWith(CI, Res); + + // We need to emit an AND to clear the high bits. 
+ Constant *C = ConstantInt::get(Res->getType(), + APInt::getLowBitsSet(DestBitSize, SrcBitsKept)); + return BinaryOperator::CreateAnd(Res, C); + } + + // If this is a TRUNC followed by a ZEXT then we are dealing with integral + // types and if the sizes are just right we can convert this into a logical + // 'and' which will be much cheaper than the pair of casts. + if (TruncInst *CSrc = dyn_cast<TruncInst>(Src)) { // A->B->C cast + // TODO: Subsume this into EvaluateInDifferentType. + + // Get the sizes of the types involved. We know that the intermediate type + // will be smaller than A or C, but don't know the relation between A and C. + Value *A = CSrc->getOperand(0); + unsigned SrcSize = A->getType()->getScalarSizeInBits(); + unsigned MidSize = CSrc->getType()->getScalarSizeInBits(); + unsigned DstSize = CI.getType()->getScalarSizeInBits(); + // If we're actually extending zero bits, then if + // SrcSize < DstSize: zext(a & mask) + // SrcSize == DstSize: a & mask + // SrcSize > DstSize: trunc(a) & mask + if (SrcSize < DstSize) { + APInt AndValue(APInt::getLowBitsSet(SrcSize, MidSize)); + Constant *AndConst = ConstantInt::get(A->getType(), AndValue); + Value *And = Builder.CreateAnd(A, AndConst, CSrc->getName() + ".mask"); + return new ZExtInst(And, CI.getType()); + } + + if (SrcSize == DstSize) { + APInt AndValue(APInt::getLowBitsSet(SrcSize, MidSize)); + return BinaryOperator::CreateAnd(A, ConstantInt::get(A->getType(), + AndValue)); + } + if (SrcSize > DstSize) { + Value *Trunc = Builder.CreateTrunc(A, CI.getType()); + APInt AndValue(APInt::getLowBitsSet(DstSize, MidSize)); + return BinaryOperator::CreateAnd(Trunc, + ConstantInt::get(Trunc->getType(), + AndValue)); + } + } + + if (ICmpInst *Cmp = dyn_cast<ICmpInst>(Src)) + return transformZExtICmp(Cmp, CI); + + // zext(trunc(X) & C) -> (X & zext(C)). 
+ Constant *C; + Value *X; + if (match(Src, m_OneUse(m_And(m_Trunc(m_Value(X)), m_Constant(C)))) && + X->getType() == CI.getType()) + return BinaryOperator::CreateAnd(X, ConstantExpr::getZExt(C, CI.getType())); + + // zext((trunc(X) & C) ^ C) -> ((X & zext(C)) ^ zext(C)). + Value *And; + if (match(Src, m_OneUse(m_Xor(m_Value(And), m_Constant(C)))) && + match(And, m_OneUse(m_And(m_Trunc(m_Value(X)), m_Specific(C)))) && + X->getType() == CI.getType()) { + Constant *ZC = ConstantExpr::getZExt(C, CI.getType()); + return BinaryOperator::CreateXor(Builder.CreateAnd(X, ZC), ZC); + } + + if (match(Src, m_VScale(DL))) { + if (CI.getFunction() && + CI.getFunction()->hasFnAttribute(Attribute::VScaleRange)) { + Attribute Attr = CI.getFunction()->getFnAttribute(Attribute::VScaleRange); + if (Optional<unsigned> MaxVScale = Attr.getVScaleRangeMax()) { + unsigned TypeWidth = Src->getType()->getScalarSizeInBits(); + if (Log2_32(*MaxVScale) < TypeWidth) { + Value *VScale = Builder.CreateVScale(ConstantInt::get(DestTy, 1)); + return replaceInstUsesWith(CI, VScale); + } + } + } + } + + return nullptr; +} + +/// Transform (sext icmp) to bitwise / integer operations to eliminate the icmp. +Instruction *InstCombinerImpl::transformSExtICmp(ICmpInst *ICI, + Instruction &CI) { + Value *Op0 = ICI->getOperand(0), *Op1 = ICI->getOperand(1); + ICmpInst::Predicate Pred = ICI->getPredicate(); + + // Don't bother if Op1 isn't of vector or integer type. + if (!Op1->getType()->isIntOrIntVectorTy()) + return nullptr; + + if ((Pred == ICmpInst::ICMP_SLT && match(Op1, m_ZeroInt())) || + (Pred == ICmpInst::ICMP_SGT && match(Op1, m_AllOnes()))) { + // (x <s 0) ? -1 : 0 -> ashr x, 31 -> all ones if negative + // (x >s -1) ? 
-1 : 0 -> not (ashr x, 31) -> all ones if positive + Value *Sh = ConstantInt::get(Op0->getType(), + Op0->getType()->getScalarSizeInBits() - 1); + Value *In = Builder.CreateAShr(Op0, Sh, Op0->getName() + ".lobit"); + if (In->getType() != CI.getType()) + In = Builder.CreateIntCast(In, CI.getType(), true /*SExt*/); + + if (Pred == ICmpInst::ICMP_SGT) + In = Builder.CreateNot(In, In->getName() + ".not"); + return replaceInstUsesWith(CI, In); + } + + if (ConstantInt *Op1C = dyn_cast<ConstantInt>(Op1)) { + // If we know that only one bit of the LHS of the icmp can be set and we + // have an equality comparison with zero or a power of 2, we can transform + // the icmp and sext into bitwise/integer operations. + if (ICI->hasOneUse() && + ICI->isEquality() && (Op1C->isZero() || Op1C->getValue().isPowerOf2())){ + KnownBits Known = computeKnownBits(Op0, 0, &CI); + + APInt KnownZeroMask(~Known.Zero); + if (KnownZeroMask.isPowerOf2()) { + Value *In = ICI->getOperand(0); + + // If the icmp tests for a known zero bit we can constant fold it. + if (!Op1C->isZero() && Op1C->getValue() != KnownZeroMask) { + Value *V = Pred == ICmpInst::ICMP_NE ? + ConstantInt::getAllOnesValue(CI.getType()) : + ConstantInt::getNullValue(CI.getType()); + return replaceInstUsesWith(CI, V); + } + + if (!Op1C->isZero() == (Pred == ICmpInst::ICMP_NE)) { + // sext ((x & 2^n) == 0) -> (x >> n) - 1 + // sext ((x & 2^n) != 2^n) -> (x >> n) - 1 + unsigned ShiftAmt = KnownZeroMask.countTrailingZeros(); + // Perform a right shift to place the desired bit in the LSB. + if (ShiftAmt) + In = Builder.CreateLShr(In, + ConstantInt::get(In->getType(), ShiftAmt)); + + // At this point "In" is either 1 or 0. Subtract 1 to turn + // {1, 0} -> {0, -1}. 
+ In = Builder.CreateAdd(In, + ConstantInt::getAllOnesValue(In->getType()), + "sext"); + } else { + // sext ((x & 2^n) != 0) -> (x << bitwidth-n) a>> bitwidth-1 + // sext ((x & 2^n) == 2^n) -> (x << bitwidth-n) a>> bitwidth-1 + unsigned ShiftAmt = KnownZeroMask.countLeadingZeros(); + // Perform a left shift to place the desired bit in the MSB. + if (ShiftAmt) + In = Builder.CreateShl(In, + ConstantInt::get(In->getType(), ShiftAmt)); + + // Distribute the bit over the whole bit width. + In = Builder.CreateAShr(In, ConstantInt::get(In->getType(), + KnownZeroMask.getBitWidth() - 1), "sext"); + } + + if (CI.getType() == In->getType()) + return replaceInstUsesWith(CI, In); + return CastInst::CreateIntegerCast(In, CI.getType(), true/*SExt*/); + } + } + } + + return nullptr; +} + +/// Return true if we can take the specified value and return it as type Ty +/// without inserting any new casts and without changing the value of the common +/// low bits. This is used by code that tries to promote integer operations to +/// a wider types will allow us to eliminate the extension. +/// +/// This function works on both vectors and scalars. +/// +static bool canEvaluateSExtd(Value *V, Type *Ty) { + assert(V->getType()->getScalarSizeInBits() < Ty->getScalarSizeInBits() && + "Can't sign extend type to a smaller type"); + if (canAlwaysEvaluateInType(V, Ty)) + return true; + if (canNotEvaluateInType(V, Ty)) + return false; + + auto *I = cast<Instruction>(V); + switch (I->getOpcode()) { + case Instruction::SExt: // sext(sext(x)) -> sext(x) + case Instruction::ZExt: // sext(zext(x)) -> zext(x) + case Instruction::Trunc: // sext(trunc(x)) -> trunc(x) or sext(x) + return true; + case Instruction::And: + case Instruction::Or: + case Instruction::Xor: + case Instruction::Add: + case Instruction::Sub: + case Instruction::Mul: + // These operators can all arbitrarily be extended if their inputs can. 
+ return canEvaluateSExtd(I->getOperand(0), Ty) && + canEvaluateSExtd(I->getOperand(1), Ty); + + //case Instruction::Shl: TODO + //case Instruction::LShr: TODO + + case Instruction::Select: + return canEvaluateSExtd(I->getOperand(1), Ty) && + canEvaluateSExtd(I->getOperand(2), Ty); + + case Instruction::PHI: { + // We can change a phi if we can change all operands. Note that we never + // get into trouble with cyclic PHIs here because we only consider + // instructions with a single use. + PHINode *PN = cast<PHINode>(I); + for (Value *IncValue : PN->incoming_values()) + if (!canEvaluateSExtd(IncValue, Ty)) return false; + return true; + } + default: + // TODO: Can handle more cases here. + break; + } + + return false; +} + +Instruction *InstCombinerImpl::visitSExt(SExtInst &CI) { + // If this sign extend is only used by a truncate, let the truncate be + // eliminated before we try to optimize this sext. + if (CI.hasOneUse() && isa<TruncInst>(CI.user_back())) + return nullptr; + + if (Instruction *I = commonCastTransforms(CI)) + return I; + + Value *Src = CI.getOperand(0); + Type *SrcTy = Src->getType(), *DestTy = CI.getType(); + unsigned SrcBitSize = SrcTy->getScalarSizeInBits(); + unsigned DestBitSize = DestTy->getScalarSizeInBits(); + + // If the value being extended is zero or positive, use a zext instead. + if (isKnownNonNegative(Src, DL, 0, &AC, &CI, &DT)) + return CastInst::Create(Instruction::ZExt, Src, DestTy); + + // Try to extend the entire expression tree to the wide destination type. + if (shouldChangeType(SrcTy, DestTy) && canEvaluateSExtd(Src, DestTy)) { + // Okay, we can transform this! Insert the new expression now. + LLVM_DEBUG( + dbgs() << "ICE: EvaluateInDifferentType converting expression type" + " to avoid sign extend: " + << CI << '\n'); + Value *Res = EvaluateInDifferentType(Src, DestTy, true); + assert(Res->getType() == DestTy); + + // If the high bits are already filled with sign bit, just replace this + // cast with the result. 
+ if (ComputeNumSignBits(Res, 0, &CI) > DestBitSize - SrcBitSize) + return replaceInstUsesWith(CI, Res); + + // We need to emit a shl + ashr to do the sign extend. + Value *ShAmt = ConstantInt::get(DestTy, DestBitSize-SrcBitSize); + return BinaryOperator::CreateAShr(Builder.CreateShl(Res, ShAmt, "sext"), + ShAmt); + } + + Value *X; + if (match(Src, m_Trunc(m_Value(X)))) { + // If the input has more sign bits than bits truncated, then convert + // directly to final type. + unsigned XBitSize = X->getType()->getScalarSizeInBits(); + if (ComputeNumSignBits(X, 0, &CI) > XBitSize - SrcBitSize) + return CastInst::CreateIntegerCast(X, DestTy, /* isSigned */ true); + + // If input is a trunc from the destination type, then convert into shifts. + if (Src->hasOneUse() && X->getType() == DestTy) { + // sext (trunc X) --> ashr (shl X, C), C + Constant *ShAmt = ConstantInt::get(DestTy, DestBitSize - SrcBitSize); + return BinaryOperator::CreateAShr(Builder.CreateShl(X, ShAmt), ShAmt); + } + + // If we are replacing shifted-in high zero bits with sign bits, convert + // the logic shift to arithmetic shift and eliminate the cast to + // intermediate type: + // sext (trunc (lshr Y, C)) --> sext/trunc (ashr Y, C) + Value *Y; + if (Src->hasOneUse() && + match(X, m_LShr(m_Value(Y), + m_SpecificIntAllowUndef(XBitSize - SrcBitSize)))) { + Value *Ashr = Builder.CreateAShr(Y, XBitSize - SrcBitSize); + return CastInst::CreateIntegerCast(Ashr, DestTy, /* isSigned */ true); + } + } + + if (ICmpInst *ICI = dyn_cast<ICmpInst>(Src)) + return transformSExtICmp(ICI, CI); + + // If the input is a shl/ashr pair of a same constant, then this is a sign + // extension from a smaller value. If we could trust arbitrary bitwidth + // integers, we could turn this into a truncate to the smaller bit and then + // use a sext for the whole extension. Since we don't, look deeper and check + // for a truncate. If the source and dest are the same type, eliminate the + // trunc and extend and just do shifts. 
For example, turn: + // %a = trunc i32 %i to i8 + // %b = shl i8 %a, C + // %c = ashr i8 %b, C + // %d = sext i8 %c to i32 + // into: + // %a = shl i32 %i, 32-(8-C) + // %d = ashr i32 %a, 32-(8-C) + Value *A = nullptr; + // TODO: Eventually this could be subsumed by EvaluateInDifferentType. + Constant *BA = nullptr, *CA = nullptr; + if (match(Src, m_AShr(m_Shl(m_Trunc(m_Value(A)), m_Constant(BA)), + m_Constant(CA))) && + BA->isElementWiseEqual(CA) && A->getType() == DestTy) { + Constant *WideCurrShAmt = ConstantExpr::getSExt(CA, DestTy); + Constant *NumLowbitsLeft = ConstantExpr::getSub( + ConstantInt::get(DestTy, SrcTy->getScalarSizeInBits()), WideCurrShAmt); + Constant *NewShAmt = ConstantExpr::getSub( + ConstantInt::get(DestTy, DestTy->getScalarSizeInBits()), + NumLowbitsLeft); + NewShAmt = + Constant::mergeUndefsWith(Constant::mergeUndefsWith(NewShAmt, BA), CA); + A = Builder.CreateShl(A, NewShAmt, CI.getName()); + return BinaryOperator::CreateAShr(A, NewShAmt); + } + + // Splatting a bit of constant-index across a value: + // sext (ashr (trunc iN X to iM), M-1) to iN --> ashr (shl X, N-M), N-1 + // If the dest type is different, use a cast (adjust use check). 
+ if (match(Src, m_OneUse(m_AShr(m_Trunc(m_Value(X)), + m_SpecificInt(SrcBitSize - 1))))) { + Type *XTy = X->getType(); + unsigned XBitSize = XTy->getScalarSizeInBits(); + Constant *ShlAmtC = ConstantInt::get(XTy, XBitSize - SrcBitSize); + Constant *AshrAmtC = ConstantInt::get(XTy, XBitSize - 1); + if (XTy == DestTy) + return BinaryOperator::CreateAShr(Builder.CreateShl(X, ShlAmtC), + AshrAmtC); + if (cast<BinaryOperator>(Src)->getOperand(0)->hasOneUse()) { + Value *Ashr = Builder.CreateAShr(Builder.CreateShl(X, ShlAmtC), AshrAmtC); + return CastInst::CreateIntegerCast(Ashr, DestTy, /* isSigned */ true); + } + } + + if (match(Src, m_VScale(DL))) { + if (CI.getFunction() && + CI.getFunction()->hasFnAttribute(Attribute::VScaleRange)) { + Attribute Attr = CI.getFunction()->getFnAttribute(Attribute::VScaleRange); + if (Optional<unsigned> MaxVScale = Attr.getVScaleRangeMax()) { + if (Log2_32(*MaxVScale) < (SrcBitSize - 1)) { + Value *VScale = Builder.CreateVScale(ConstantInt::get(DestTy, 1)); + return replaceInstUsesWith(CI, VScale); + } + } + } + } + + return nullptr; +} + +/// Return true if the specified floating-point constant fits in the specified +/// FP type without changing its value. +static bool fitsInFPType(ConstantFP *CFP, const fltSemantics &Sem) { + bool losesInfo; + APFloat F = CFP->getValueAPF(); + (void)F.convert(Sem, APFloat::rmNearestTiesToEven, &losesInfo); + return !losesInfo; +} + +static Type *shrinkFPConstant(ConstantFP *CFP) { + if (CFP->getType() == Type::getPPC_FP128Ty(CFP->getContext())) + return nullptr; // No constant folding of this. + // See if the value can be truncated to half and then reextended. + if (fitsInFPType(CFP, APFloat::IEEEhalf())) + return Type::getHalfTy(CFP->getContext()); + // See if the value can be truncated to float and then reextended. + if (fitsInFPType(CFP, APFloat::IEEEsingle())) + return Type::getFloatTy(CFP->getContext()); + if (CFP->getType()->isDoubleTy()) + return nullptr; // Won't shrink.
+ if (fitsInFPType(CFP, APFloat::IEEEdouble())) + return Type::getDoubleTy(CFP->getContext()); + // Don't try to shrink to various long double types. + return nullptr; +} + +// Determine if this is a vector of ConstantFPs and if so, return the minimal +// type we can safely truncate all elements to. +// TODO: Make these support undef elements. +static Type *shrinkFPConstantVector(Value *V) { + auto *CV = dyn_cast<Constant>(V); + auto *CVVTy = dyn_cast<FixedVectorType>(V->getType()); + if (!CV || !CVVTy) + return nullptr; + + Type *MinType = nullptr; + + unsigned NumElts = CVVTy->getNumElements(); + + // For fixed-width vectors we find the minimal type by looking + // through the constant values of the vector. + for (unsigned i = 0; i != NumElts; ++i) { + auto *CFP = dyn_cast_or_null<ConstantFP>(CV->getAggregateElement(i)); + if (!CFP) + return nullptr; + + Type *T = shrinkFPConstant(CFP); + if (!T) + return nullptr; + + // If we haven't found a type yet or this type has a larger mantissa than + // our previous type, this is our new minimal type. + if (!MinType || T->getFPMantissaWidth() > MinType->getFPMantissaWidth()) + MinType = T; + } + + // Make a vector type from the minimal type. + return FixedVectorType::get(MinType, NumElts); +} + +/// Find the minimum FP type we can safely truncate to. +static Type *getMinimumFPType(Value *V) { + if (auto *FPExt = dyn_cast<FPExtInst>(V)) + return FPExt->getOperand(0)->getType(); + + // If this value is a constant, return the constant in the smallest FP type + // that can accurately represent it. This allows us to turn + // (float)((double)X+2.0) into x+2.0f. + if (auto *CFP = dyn_cast<ConstantFP>(V)) + if (Type *T = shrinkFPConstant(CFP)) + return T; + + // We can only correctly find a minimum type for a scalable vector when it is + // a splat. For splats of constant values the fpext is wrapped up as a + // ConstantExpr. 
+ if (auto *FPCExt = dyn_cast<ConstantExpr>(V)) + if (FPCExt->getOpcode() == Instruction::FPExt) + return FPCExt->getOperand(0)->getType(); + + // Try to shrink a vector of FP constants. This returns nullptr on scalable + // vectors + if (Type *T = shrinkFPConstantVector(V)) + return T; + + return V->getType(); +} + +/// Return true if the cast from integer to FP can be proven to be exact for all +/// possible inputs (the conversion does not lose any precision). +static bool isKnownExactCastIntToFP(CastInst &I, InstCombinerImpl &IC) { + CastInst::CastOps Opcode = I.getOpcode(); + assert((Opcode == CastInst::SIToFP || Opcode == CastInst::UIToFP) && + "Unexpected cast"); + Value *Src = I.getOperand(0); + Type *SrcTy = Src->getType(); + Type *FPTy = I.getType(); + bool IsSigned = Opcode == Instruction::SIToFP; + int SrcSize = (int)SrcTy->getScalarSizeInBits() - IsSigned; + + // Easy case - if the source integer type has less bits than the FP mantissa, + // then the cast must be exact. + int DestNumSigBits = FPTy->getFPMantissaWidth(); + if (SrcSize <= DestNumSigBits) + return true; + + // Cast from FP to integer and back to FP is independent of the intermediate + // integer width because of poison on overflow. + Value *F; + if (match(Src, m_FPToSI(m_Value(F))) || match(Src, m_FPToUI(m_Value(F)))) { + // If this is uitofp (fptosi F), the source needs an extra bit to avoid + // potential rounding of negative FP input values. + int SrcNumSigBits = F->getType()->getFPMantissaWidth(); + if (!IsSigned && match(Src, m_FPToSI(m_Value()))) + SrcNumSigBits++; + + // [su]itofp (fpto[su]i F) --> exact if the source type has less or equal + // significant bits than the destination (and make sure neither type is + // weird -- ppc_fp128). + if (SrcNumSigBits > 0 && DestNumSigBits > 0 && + SrcNumSigBits <= DestNumSigBits) + return true; + } + + // TODO: + // Try harder to find if the source integer type has less significant bits. 
+ // For example, compute number of sign bits or compute low bit mask. + KnownBits SrcKnown = IC.computeKnownBits(Src, 0, &I); + int LowBits = + (int)SrcTy->getScalarSizeInBits() - SrcKnown.countMinLeadingZeros(); + if (LowBits <= DestNumSigBits) + return true; + + return false; +} + +Instruction *InstCombinerImpl::visitFPTrunc(FPTruncInst &FPT) { + if (Instruction *I = commonCastTransforms(FPT)) + return I; + + // If we have fptrunc(OpI (fpextend x), (fpextend y)), we would like to + // simplify this expression to avoid one or more of the trunc/extend + // operations if we can do so without changing the numerical results. + // + // The exact manner in which the widths of the operands interact to limit + // what we can and cannot do safely varies from operation to operation, and + // is explained below in the various case statements. + Type *Ty = FPT.getType(); + auto *BO = dyn_cast<BinaryOperator>(FPT.getOperand(0)); + if (BO && BO->hasOneUse()) { + Type *LHSMinType = getMinimumFPType(BO->getOperand(0)); + Type *RHSMinType = getMinimumFPType(BO->getOperand(1)); + unsigned OpWidth = BO->getType()->getFPMantissaWidth(); + unsigned LHSWidth = LHSMinType->getFPMantissaWidth(); + unsigned RHSWidth = RHSMinType->getFPMantissaWidth(); + unsigned SrcWidth = std::max(LHSWidth, RHSWidth); + unsigned DstWidth = Ty->getFPMantissaWidth(); + switch (BO->getOpcode()) { + default: break; + case Instruction::FAdd: + case Instruction::FSub: + // For addition and subtraction, the infinitely precise result can + // essentially be arbitrarily wide; proving that double rounding + // will not occur because the result of OpI is exact (as we will for + // FMul, for example) is hopeless. However, we *can* nonetheless + // frequently know that double rounding cannot occur (or that it is + // innocuous) by taking advantage of the specific structure of + // infinitely-precise results that admit double rounding. 
+ // + // Specifically, if OpWidth >= 2*DstWidth+1 and DstWidth is sufficient + // to represent both sources, we can guarantee that the double + // rounding is innocuous (See p50 of Figueroa's 2000 PhD thesis, + // "A Rigorous Framework for Fully Supporting the IEEE Standard ..." + // for proof of this fact). + // + // Note: Figueroa does not consider the case where DstFormat != + // SrcFormat. It's possible (likely even!) that this analysis + // could be tightened for those cases, but they are rare (the main + // case of interest here is (float)((double)float + float)). + if (OpWidth >= 2*DstWidth+1 && DstWidth >= SrcWidth) { + Value *LHS = Builder.CreateFPTrunc(BO->getOperand(0), Ty); + Value *RHS = Builder.CreateFPTrunc(BO->getOperand(1), Ty); + Instruction *RI = BinaryOperator::Create(BO->getOpcode(), LHS, RHS); + RI->copyFastMathFlags(BO); + return RI; + } + break; + case Instruction::FMul: + // For multiplication, the infinitely precise result has at most + // LHSWidth + RHSWidth significant bits; if OpWidth is sufficient + // that such a value can be exactly represented, then no double + // rounding can possibly occur; we can safely perform the operation + // in the destination format if it can represent both sources. + if (OpWidth >= LHSWidth + RHSWidth && DstWidth >= SrcWidth) { + Value *LHS = Builder.CreateFPTrunc(BO->getOperand(0), Ty); + Value *RHS = Builder.CreateFPTrunc(BO->getOperand(1), Ty); + return BinaryOperator::CreateFMulFMF(LHS, RHS, BO); + } + break; + case Instruction::FDiv: + // For division, we again use the bound from Figueroa's + // dissertation. I am entirely certain that this bound can be + // tightened in the unbalanced operand case by an analysis based on + // the diophantine rational approximation bound, but the well-known + // condition used here is a good conservative first pass. + // TODO: Tighten bound via rigorous analysis of the unbalanced case.
+ if (OpWidth >= 2*DstWidth && DstWidth >= SrcWidth) { + Value *LHS = Builder.CreateFPTrunc(BO->getOperand(0), Ty); + Value *RHS = Builder.CreateFPTrunc(BO->getOperand(1), Ty); + return BinaryOperator::CreateFDivFMF(LHS, RHS, BO); + } + break; + case Instruction::FRem: { + // Remainder is straightforward. Remainder is always exact, so the + // type of OpI doesn't enter into things at all. We simply evaluate + // in whichever source type is larger, then convert to the + // destination type. + if (SrcWidth == OpWidth) + break; + Value *LHS, *RHS; + if (LHSWidth == SrcWidth) { + LHS = Builder.CreateFPTrunc(BO->getOperand(0), LHSMinType); + RHS = Builder.CreateFPTrunc(BO->getOperand(1), LHSMinType); + } else { + LHS = Builder.CreateFPTrunc(BO->getOperand(0), RHSMinType); + RHS = Builder.CreateFPTrunc(BO->getOperand(1), RHSMinType); + } + + Value *ExactResult = Builder.CreateFRemFMF(LHS, RHS, BO); + return CastInst::CreateFPCast(ExactResult, Ty); + } + } + } + + // (fptrunc (fneg x)) -> (fneg (fptrunc x)) + Value *X; + Instruction *Op = dyn_cast<Instruction>(FPT.getOperand(0)); + if (Op && Op->hasOneUse()) { + // FIXME: The FMF should propagate from the fptrunc, not the source op. + IRBuilder<>::FastMathFlagGuard FMFG(Builder); + if (isa<FPMathOperator>(Op)) + Builder.setFastMathFlags(Op->getFastMathFlags()); + + if (match(Op, m_FNeg(m_Value(X)))) { + Value *InnerTrunc = Builder.CreateFPTrunc(X, Ty); + + return UnaryOperator::CreateFNegFMF(InnerTrunc, Op); + } + + // If we are truncating a select that has an extended operand, we can + // narrow the other operand and do the select as a narrow op. 
+ Value *Cond, *X, *Y; + if (match(Op, m_Select(m_Value(Cond), m_FPExt(m_Value(X)), m_Value(Y))) && + X->getType() == Ty) { + // fptrunc (select Cond, (fpext X), Y) --> select Cond, X, (fptrunc Y) + Value *NarrowY = Builder.CreateFPTrunc(Y, Ty); + Value *Sel = Builder.CreateSelect(Cond, X, NarrowY, "narrow.sel", Op); + return replaceInstUsesWith(FPT, Sel); + } + if (match(Op, m_Select(m_Value(Cond), m_Value(Y), m_FPExt(m_Value(X)))) && + X->getType() == Ty) { + // fptrunc (select Cond, Y, (fpext X)) --> select Cond, (fptrunc Y), X + Value *NarrowY = Builder.CreateFPTrunc(Y, Ty); + Value *Sel = Builder.CreateSelect(Cond, NarrowY, X, "narrow.sel", Op); + return replaceInstUsesWith(FPT, Sel); + } + } + + if (auto *II = dyn_cast<IntrinsicInst>(FPT.getOperand(0))) { + switch (II->getIntrinsicID()) { + default: break; + case Intrinsic::ceil: + case Intrinsic::fabs: + case Intrinsic::floor: + case Intrinsic::nearbyint: + case Intrinsic::rint: + case Intrinsic::round: + case Intrinsic::roundeven: + case Intrinsic::trunc: { + Value *Src = II->getArgOperand(0); + if (!Src->hasOneUse()) + break; + + // Except for fabs, this transformation requires the input of the unary FP + // operation to be itself an fpext from the type to which we're + // truncating. + if (II->getIntrinsicID() != Intrinsic::fabs) { + FPExtInst *FPExtSrc = dyn_cast<FPExtInst>(Src); + if (!FPExtSrc || FPExtSrc->getSrcTy() != Ty) + break; + } + + // Do unary FP operation on smaller type.
+ // (fptrunc (fabs x)) -> (fabs (fptrunc x)) + Value *InnerTrunc = Builder.CreateFPTrunc(Src, Ty); + Function *Overload = Intrinsic::getDeclaration(FPT.getModule(), + II->getIntrinsicID(), Ty); + SmallVector<OperandBundleDef, 1> OpBundles; + II->getOperandBundlesAsDefs(OpBundles); + CallInst *NewCI = + CallInst::Create(Overload, {InnerTrunc}, OpBundles, II->getName()); + NewCI->copyFastMathFlags(II); + return NewCI; + } + } + } + + if (Instruction *I = shrinkInsertElt(FPT, Builder)) + return I; + + Value *Src = FPT.getOperand(0); + if (isa<SIToFPInst>(Src) || isa<UIToFPInst>(Src)) { + auto *FPCast = cast<CastInst>(Src); + if (isKnownExactCastIntToFP(*FPCast, *this)) + return CastInst::Create(FPCast->getOpcode(), FPCast->getOperand(0), Ty); + } + + return nullptr; +} + +Instruction *InstCombinerImpl::visitFPExt(CastInst &FPExt) { + // If the source operand is a cast from integer to FP and known exact, then + // cast the integer operand directly to the destination type. + Type *Ty = FPExt.getType(); + Value *Src = FPExt.getOperand(0); + if (isa<SIToFPInst>(Src) || isa<UIToFPInst>(Src)) { + auto *FPCast = cast<CastInst>(Src); + if (isKnownExactCastIntToFP(*FPCast, *this)) + return CastInst::Create(FPCast->getOpcode(), FPCast->getOperand(0), Ty); + } + + return commonCastTransforms(FPExt); +} + +/// fpto{s/u}i({u/s}itofp(X)) --> X or zext(X) or sext(X) or trunc(X) +/// This is safe if the intermediate type has enough bits in its mantissa to +/// accurately represent all values of X. For example, this won't work with +/// i64 -> float -> i64. 
+Instruction *InstCombinerImpl::foldItoFPtoI(CastInst &FI) { + if (!isa<UIToFPInst>(FI.getOperand(0)) && !isa<SIToFPInst>(FI.getOperand(0))) + return nullptr; + + auto *OpI = cast<CastInst>(FI.getOperand(0)); + Value *X = OpI->getOperand(0); + Type *XType = X->getType(); + Type *DestType = FI.getType(); + bool IsOutputSigned = isa<FPToSIInst>(FI); + + // Since we can assume the conversion won't overflow, our decision as to + // whether the input will fit in the float should depend on the minimum + // of the input range and output range. + + // This means this is also safe for a signed input and unsigned output, since + // a negative input would lead to undefined behavior. + if (!isKnownExactCastIntToFP(*OpI, *this)) { + // The first cast may not round exactly based on the source integer width + // and FP width, but the overflow UB rules can still allow this to fold. + // If the destination type is narrow, that means the intermediate FP value + // must be large enough to hold the source value exactly. + // For example, (uint8_t)((float)(uint32_t 16777217) is undefined behavior. 
+ int OutputSize = (int)DestType->getScalarSizeInBits(); + if (OutputSize > OpI->getType()->getFPMantissaWidth()) + return nullptr; + } + + if (DestType->getScalarSizeInBits() > XType->getScalarSizeInBits()) { + bool IsInputSigned = isa<SIToFPInst>(OpI); + if (IsInputSigned && IsOutputSigned) + return new SExtInst(X, DestType); + return new ZExtInst(X, DestType); + } + if (DestType->getScalarSizeInBits() < XType->getScalarSizeInBits()) + return new TruncInst(X, DestType); + + assert(XType == DestType && "Unexpected types for int to FP to int casts"); + return replaceInstUsesWith(FI, X); +} + +Instruction *InstCombinerImpl::visitFPToUI(FPToUIInst &FI) { + if (Instruction *I = foldItoFPtoI(FI)) + return I; + + return commonCastTransforms(FI); +} + +Instruction *InstCombinerImpl::visitFPToSI(FPToSIInst &FI) { + if (Instruction *I = foldItoFPtoI(FI)) + return I; + + return commonCastTransforms(FI); +} + +Instruction *InstCombinerImpl::visitUIToFP(CastInst &CI) { + return commonCastTransforms(CI); +} + +Instruction *InstCombinerImpl::visitSIToFP(CastInst &CI) { + return commonCastTransforms(CI); +} + +Instruction *InstCombinerImpl::visitIntToPtr(IntToPtrInst &CI) { + // If the source integer type is not the intptr_t type for this target, do a + // trunc or zext to the intptr_t type, then inttoptr of it. This allows the + // cast to be exposed to other transforms. 
+ unsigned AS = CI.getAddressSpace(); + if (CI.getOperand(0)->getType()->getScalarSizeInBits() != + DL.getPointerSizeInBits(AS)) { + Type *Ty = CI.getOperand(0)->getType()->getWithNewType( + DL.getIntPtrType(CI.getContext(), AS)); + Value *P = Builder.CreateZExtOrTrunc(CI.getOperand(0), Ty); + return new IntToPtrInst(P, CI.getType()); + } + + if (Instruction *I = commonCastTransforms(CI)) + return I; + + return nullptr; +} + +/// Implement the transforms for cast of pointer (bitcast/ptrtoint) +Instruction *InstCombinerImpl::commonPointerCastTransforms(CastInst &CI) { + Value *Src = CI.getOperand(0); + + if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Src)) { + // If casting the result of a getelementptr instruction with no offset, turn + // this into a cast of the original pointer! + if (GEP->hasAllZeroIndices() && + // If CI is an addrspacecast and GEP changes the pointer type, merging + // GEP into CI would undo canonicalizing addrspacecast with different + // pointer types, causing infinite loops. + (!isa<AddrSpaceCastInst>(CI) || + GEP->getType() == GEP->getPointerOperandType())) { + // Changing the cast operand is usually not a good idea but it is safe + // here because the pointer operand is being replaced with another + // pointer operand so the opcode doesn't need to change. + return replaceOperand(CI, 0, GEP->getOperand(0)); + } + } + + return commonCastTransforms(CI); +} + +Instruction *InstCombinerImpl::visitPtrToInt(PtrToIntInst &CI) { + // If the destination integer type is not the intptr_t type for this target, + // do a ptrtoint to intptr_t then do a trunc or zext. This allows the cast + // to be exposed to other transforms.
+ Value *SrcOp = CI.getPointerOperand(); + Type *SrcTy = SrcOp->getType(); + Type *Ty = CI.getType(); + unsigned AS = CI.getPointerAddressSpace(); + unsigned TySize = Ty->getScalarSizeInBits(); + unsigned PtrSize = DL.getPointerSizeInBits(AS); + if (TySize != PtrSize) { + Type *IntPtrTy = + SrcTy->getWithNewType(DL.getIntPtrType(CI.getContext(), AS)); + Value *P = Builder.CreatePtrToInt(SrcOp, IntPtrTy); + return CastInst::CreateIntegerCast(P, Ty, /*isSigned=*/false); + } + + if (auto *GEP = dyn_cast<GetElementPtrInst>(SrcOp)) { + // Fold ptrtoint(gep null, x) to multiply + constant if the GEP has one use. + // While this can increase the number of instructions it doesn't actually + // increase the overall complexity since the arithmetic is just part of + // the GEP otherwise. + if (GEP->hasOneUse() && + isa<ConstantPointerNull>(GEP->getPointerOperand())) { + return replaceInstUsesWith(CI, + Builder.CreateIntCast(EmitGEPOffset(GEP), Ty, + /*isSigned=*/false)); + } + } + + Value *Vec, *Scalar, *Index; + if (match(SrcOp, m_OneUse(m_InsertElt(m_IntToPtr(m_Value(Vec)), + m_Value(Scalar), m_Value(Index)))) && + Vec->getType() == Ty) { + assert(Vec->getType()->getScalarSizeInBits() == PtrSize && "Wrong type"); + // Convert the scalar to int followed by insert to eliminate one cast: + // p2i (ins (i2p Vec), Scalar, Index) --> ins Vec, (p2i Scalar), Index + Value *NewCast = Builder.CreatePtrToInt(Scalar, Ty->getScalarType()); + return InsertElementInst::Create(Vec, NewCast, Index); + } + + return commonPointerCastTransforms(CI); +} + +/// This input value (which is known to have vector type) is being zero extended +/// or truncated to the specified vector type. Since the zext/trunc is done +/// using an integer type, we have a (bitcast(cast(bitcast))) pattern, and +/// endianness will impact which end of the vector is extended or +/// truncated.
+/// +/// A vector is always stored with index 0 at the lowest address, which +/// corresponds to the most significant bits for a big endian stored integer and +/// the least significant bits for little endian. A trunc/zext of an integer +/// impacts the big end of the integer. Thus, we need to add/remove elements at +/// the front of the vector for big endian targets, and the back of the vector +/// for little endian targets. +/// +/// Try to replace it with a shuffle (and vector/vector bitcast) if possible. +/// +/// The source and destination vector types may have different element types. +static Instruction * +optimizeVectorResizeWithIntegerBitCasts(Value *InVal, VectorType *DestTy, + InstCombinerImpl &IC) { + // We can only do this optimization if the output is a multiple of the input + // element size, or the input is a multiple of the output element size. + // Convert the input type to have the same element type as the output. + VectorType *SrcTy = cast<VectorType>(InVal->getType()); + + if (SrcTy->getElementType() != DestTy->getElementType()) { + // The input types don't need to be identical, but for now they must be the + // same size. There is no specific reason we couldn't handle things like + // <4 x i16> -> <4 x i32> by bitcasting to <2 x i32> but haven't gotten + // there yet. 
+ if (SrcTy->getElementType()->getPrimitiveSizeInBits() != + DestTy->getElementType()->getPrimitiveSizeInBits()) + return nullptr; + + SrcTy = + FixedVectorType::get(DestTy->getElementType(), + cast<FixedVectorType>(SrcTy)->getNumElements()); + InVal = IC.Builder.CreateBitCast(InVal, SrcTy); + } + + bool IsBigEndian = IC.getDataLayout().isBigEndian(); + unsigned SrcElts = cast<FixedVectorType>(SrcTy)->getNumElements(); + unsigned DestElts = cast<FixedVectorType>(DestTy)->getNumElements(); + + assert(SrcElts != DestElts && "Element counts should be different."); + + // Now that the element types match, get the shuffle mask and RHS of the + // shuffle to use, which depends on whether we're increasing or decreasing the + // size of the input. + auto ShuffleMaskStorage = llvm::to_vector<16>(llvm::seq<int>(0, SrcElts)); + ArrayRef<int> ShuffleMask; + Value *V2; + + if (SrcElts > DestElts) { + // If we're shrinking the number of elements (rewriting an integer + // truncate), just shuffle in the elements corresponding to the least + // significant bits from the input and use poison as the second shuffle + // input. + V2 = PoisonValue::get(SrcTy); + // Make sure the shuffle mask selects the "least significant bits" by + // keeping elements from back of the src vector for big endian, and from the + // front for little endian. + ShuffleMask = ShuffleMaskStorage; + if (IsBigEndian) + ShuffleMask = ShuffleMask.take_back(DestElts); + else + ShuffleMask = ShuffleMask.take_front(DestElts); + } else { + // If we're increasing the number of elements (rewriting an integer zext), + // shuffle in all of the elements from InVal. Fill the rest of the result + // elements with zeros from a constant zero. + V2 = Constant::getNullValue(SrcTy); + // Use first elt from V2 when indicating zero in the shuffle mask. 
+ uint32_t NullElt = SrcElts; + // Extend with null values in the "most significant bits" by adding elements + // in front of the src vector for big endian, and at the back for little + // endian. + unsigned DeltaElts = DestElts - SrcElts; + if (IsBigEndian) + ShuffleMaskStorage.insert(ShuffleMaskStorage.begin(), DeltaElts, NullElt); + else + ShuffleMaskStorage.append(DeltaElts, NullElt); + ShuffleMask = ShuffleMaskStorage; + } + + return new ShuffleVectorInst(InVal, V2, ShuffleMask); +} + +static bool isMultipleOfTypeSize(unsigned Value, Type *Ty) { + return Value % Ty->getPrimitiveSizeInBits() == 0; +} + +static unsigned getTypeSizeIndex(unsigned Value, Type *Ty) { + return Value / Ty->getPrimitiveSizeInBits(); +} + +/// V is a value which is inserted into a vector of VecEltTy. +/// Look through the value to see if we can decompose it into +/// insertions into the vector. See the example in the comment for +/// OptimizeIntegerToVectorInsertions for the pattern this handles. +/// The type of V is always a non-zero multiple of VecEltTy's size. +/// Shift is the number of bits between the lsb of V and the lsb of +/// the vector. +/// +/// This returns false if the pattern can't be matched or true if it can, +/// filling in Elements with the elements found here. +static bool collectInsertionElements(Value *V, unsigned Shift, + SmallVectorImpl<Value *> &Elements, + Type *VecEltTy, bool isBigEndian) { + assert(isMultipleOfTypeSize(Shift, VecEltTy) && + "Shift should be a multiple of the element type size"); + + // Undef values never contribute useful bits to the result. + if (isa<UndefValue>(V)) return true; + + // If we got down to a value of the right type, we win, try inserting into the + // right element. + if (V->getType() == VecEltTy) { + // Inserting null doesn't actually insert any elements. 
+ if (Constant *C = dyn_cast<Constant>(V)) + if (C->isNullValue()) + return true; + + unsigned ElementIndex = getTypeSizeIndex(Shift, VecEltTy); + if (isBigEndian) + ElementIndex = Elements.size() - ElementIndex - 1; + + // Fail if multiple elements are inserted into this slot. + if (Elements[ElementIndex]) + return false; + + Elements[ElementIndex] = V; + return true; + } + + if (Constant *C = dyn_cast<Constant>(V)) { + // Figure out the # elements this provides, and bitcast it or slice it up + // as required. + unsigned NumElts = getTypeSizeIndex(C->getType()->getPrimitiveSizeInBits(), + VecEltTy); + // If the constant is the size of a vector element, we just need to bitcast + // it to the right type so it gets properly inserted. + if (NumElts == 1) + return collectInsertionElements(ConstantExpr::getBitCast(C, VecEltTy), + Shift, Elements, VecEltTy, isBigEndian); + + // Okay, this is a constant that covers multiple elements. Slice it up into + // pieces and insert each element-sized piece into the vector. + if (!isa<IntegerType>(C->getType())) + C = ConstantExpr::getBitCast(C, IntegerType::get(V->getContext(), + C->getType()->getPrimitiveSizeInBits())); + unsigned ElementSize = VecEltTy->getPrimitiveSizeInBits(); + Type *ElementIntTy = IntegerType::get(C->getContext(), ElementSize); + + for (unsigned i = 0; i != NumElts; ++i) { + unsigned ShiftI = Shift+i*ElementSize; + Constant *Piece = ConstantExpr::getLShr(C, ConstantInt::get(C->getType(), + ShiftI)); + Piece = ConstantExpr::getTrunc(Piece, ElementIntTy); + if (!collectInsertionElements(Piece, ShiftI, Elements, VecEltTy, + isBigEndian)) + return false; + } + return true; + } + + if (!V->hasOneUse()) return false; + + Instruction *I = dyn_cast<Instruction>(V); + if (!I) return false; + switch (I->getOpcode()) { + default: return false; // Unhandled case. 
+ case Instruction::BitCast: + if (I->getOperand(0)->getType()->isVectorTy()) + return false; + return collectInsertionElements(I->getOperand(0), Shift, Elements, VecEltTy, + isBigEndian); + case Instruction::ZExt: + if (!isMultipleOfTypeSize( + I->getOperand(0)->getType()->getPrimitiveSizeInBits(), + VecEltTy)) + return false; + return collectInsertionElements(I->getOperand(0), Shift, Elements, VecEltTy, + isBigEndian); + case Instruction::Or: + return collectInsertionElements(I->getOperand(0), Shift, Elements, VecEltTy, + isBigEndian) && + collectInsertionElements(I->getOperand(1), Shift, Elements, VecEltTy, + isBigEndian); + case Instruction::Shl: { + // Must be shifting by a constant that is a multiple of the element size. + ConstantInt *CI = dyn_cast<ConstantInt>(I->getOperand(1)); + if (!CI) return false; + Shift += CI->getZExtValue(); + if (!isMultipleOfTypeSize(Shift, VecEltTy)) return false; + return collectInsertionElements(I->getOperand(0), Shift, Elements, VecEltTy, + isBigEndian); + } + + } +} + + +/// If the input is an 'or' instruction, we may be doing shifts and ors to +/// assemble the elements of the vector manually. +/// Try to rip the code out and replace it with insertelements. This is to +/// optimize code like this: +/// +/// %tmp37 = bitcast float %inc to i32 +/// %tmp38 = zext i32 %tmp37 to i64 +/// %tmp31 = bitcast float %inc5 to i32 +/// %tmp32 = zext i32 %tmp31 to i64 +/// %tmp33 = shl i64 %tmp32, 32 +/// %ins35 = or i64 %tmp33, %tmp38 +/// %tmp43 = bitcast i64 %ins35 to <2 x float> +/// +/// Into two insertelements that do "buildvector{%inc, %inc5}". 
+static Value *optimizeIntegerToVectorInsertions(BitCastInst &CI, + InstCombinerImpl &IC) { + auto *DestVecTy = cast<FixedVectorType>(CI.getType()); + Value *IntInput = CI.getOperand(0); + + SmallVector<Value*, 8> Elements(DestVecTy->getNumElements()); + if (!collectInsertionElements(IntInput, 0, Elements, + DestVecTy->getElementType(), + IC.getDataLayout().isBigEndian())) + return nullptr; + + // If we succeeded, we know that all of the elements are specified by Elements + // or are zero if Elements has a null entry. Recast this as a set of + // insertions. + Value *Result = Constant::getNullValue(CI.getType()); + for (unsigned i = 0, e = Elements.size(); i != e; ++i) { + if (!Elements[i]) continue; // Unset element. + + Result = IC.Builder.CreateInsertElement(Result, Elements[i], + IC.Builder.getInt32(i)); + } + + return Result; +} + +/// Canonicalize scalar bitcasts of extracted elements into a bitcast of the +/// vector followed by extract element. The backend tends to handle bitcasts of +/// vectors better than bitcasts of scalars because vector registers are +/// usually not type-specific like scalar integer or scalar floating-point. +static Instruction *canonicalizeBitCastExtElt(BitCastInst &BitCast, + InstCombinerImpl &IC) { + Value *VecOp, *Index; + if (!match(BitCast.getOperand(0), + m_OneUse(m_ExtractElt(m_Value(VecOp), m_Value(Index))))) + return nullptr; + + // The bitcast must be to a vectorizable type, otherwise we can't make a new + // type to extract from. + Type *DestType = BitCast.getType(); + VectorType *VecType = cast<VectorType>(VecOp->getType()); + if (VectorType::isValidElementType(DestType)) { + auto *NewVecType = VectorType::get(DestType, VecType); + auto *NewBC = IC.Builder.CreateBitCast(VecOp, NewVecType, "bc"); + return ExtractElementInst::Create(NewBC, Index); + } + + // Only handle the case where DestType is a vector, to avoid the inverse transform in visitBitCast.
+ // bitcast (extractelement <1 x elt>, dest) -> bitcast(<1 x elt>, dest) + auto *FixedVType = dyn_cast<FixedVectorType>(VecType); + if (DestType->isVectorTy() && FixedVType && FixedVType->getNumElements() == 1) + return CastInst::Create(Instruction::BitCast, VecOp, DestType); + + return nullptr; +} + +/// Change the type of a bitwise logic operation if we can eliminate a bitcast. +static Instruction *foldBitCastBitwiseLogic(BitCastInst &BitCast, + InstCombiner::BuilderTy &Builder) { + Type *DestTy = BitCast.getType(); + BinaryOperator *BO; + + if (!match(BitCast.getOperand(0), m_OneUse(m_BinOp(BO))) || + !BO->isBitwiseLogicOp()) + return nullptr; + + // FIXME: This transform is restricted to vector types to avoid backend + // problems caused by creating potentially illegal operations. If a fix-up is + // added to handle that situation, we can remove this check. + if (!DestTy->isVectorTy() || !BO->getType()->isVectorTy()) + return nullptr; + + if (DestTy->isFPOrFPVectorTy()) { + Value *X, *Y; + // bitcast(logic(bitcast(X), bitcast(Y))) -> bitcast'(logic(bitcast'(X), Y)) + if (match(BO->getOperand(0), m_OneUse(m_BitCast(m_Value(X)))) && + match(BO->getOperand(1), m_OneUse(m_BitCast(m_Value(Y))))) { + if (X->getType()->isFPOrFPVectorTy() && + Y->getType()->isIntOrIntVectorTy()) { + Value *CastedOp = + Builder.CreateBitCast(BO->getOperand(0), Y->getType()); + Value *NewBO = Builder.CreateBinOp(BO->getOpcode(), CastedOp, Y); + return CastInst::CreateBitOrPointerCast(NewBO, DestTy); + } + if (X->getType()->isIntOrIntVectorTy() && + Y->getType()->isFPOrFPVectorTy()) { + Value *CastedOp = + Builder.CreateBitCast(BO->getOperand(1), X->getType()); + Value *NewBO = Builder.CreateBinOp(BO->getOpcode(), CastedOp, X); + return CastInst::CreateBitOrPointerCast(NewBO, DestTy); + } + } + return nullptr; + } + + if (!DestTy->isIntOrIntVectorTy()) + return nullptr; + + Value *X; + if (match(BO->getOperand(0), m_OneUse(m_BitCast(m_Value(X)))) && + X->getType() == DestTy && 
!isa<Constant>(X)) { + // bitcast(logic(bitcast(X), Y)) --> logic'(X, bitcast(Y)) + Value *CastedOp1 = Builder.CreateBitCast(BO->getOperand(1), DestTy); + return BinaryOperator::Create(BO->getOpcode(), X, CastedOp1); + } + + if (match(BO->getOperand(1), m_OneUse(m_BitCast(m_Value(X)))) && + X->getType() == DestTy && !isa<Constant>(X)) { + // bitcast(logic(Y, bitcast(X))) --> logic'(bitcast(Y), X) + Value *CastedOp0 = Builder.CreateBitCast(BO->getOperand(0), DestTy); + return BinaryOperator::Create(BO->getOpcode(), CastedOp0, X); + } + + // Canonicalize vector bitcasts to come before vector bitwise logic with a + // constant. This eases recognition of special constants for later ops. + // Example: + // icmp u/s (a ^ signmask), (b ^ signmask) --> icmp s/u a, b + Constant *C; + if (match(BO->getOperand(1), m_Constant(C))) { + // bitcast (logic X, C) --> logic (bitcast X, C') + Value *CastedOp0 = Builder.CreateBitCast(BO->getOperand(0), DestTy); + Value *CastedC = Builder.CreateBitCast(C, DestTy); + return BinaryOperator::Create(BO->getOpcode(), CastedOp0, CastedC); + } + + return nullptr; +} + +/// Change the type of a select if we can eliminate a bitcast. +static Instruction *foldBitCastSelect(BitCastInst &BitCast, + InstCombiner::BuilderTy &Builder) { + Value *Cond, *TVal, *FVal; + if (!match(BitCast.getOperand(0), + m_OneUse(m_Select(m_Value(Cond), m_Value(TVal), m_Value(FVal))))) + return nullptr; + + // A vector select must maintain the same number of elements in its operands. + Type *CondTy = Cond->getType(); + Type *DestTy = BitCast.getType(); + if (auto *CondVTy = dyn_cast<VectorType>(CondTy)) + if (!DestTy->isVectorTy() || + CondVTy->getElementCount() != + cast<VectorType>(DestTy)->getElementCount()) + return nullptr; + + // FIXME: This transform is restricted from changing the select between + // scalars and vectors to avoid backend problems caused by creating + // potentially illegal operations. 
If a fix-up is added to handle that + // situation, we can remove this check. + if (DestTy->isVectorTy() != TVal->getType()->isVectorTy()) + return nullptr; + + auto *Sel = cast<Instruction>(BitCast.getOperand(0)); + Value *X; + if (match(TVal, m_OneUse(m_BitCast(m_Value(X)))) && X->getType() == DestTy && + !isa<Constant>(X)) { + // bitcast(select(Cond, bitcast(X), Y)) --> select'(Cond, X, bitcast(Y)) + Value *CastedVal = Builder.CreateBitCast(FVal, DestTy); + return SelectInst::Create(Cond, X, CastedVal, "", nullptr, Sel); + } + + if (match(FVal, m_OneUse(m_BitCast(m_Value(X)))) && X->getType() == DestTy && + !isa<Constant>(X)) { + // bitcast(select(Cond, Y, bitcast(X))) --> select'(Cond, bitcast(Y), X) + Value *CastedVal = Builder.CreateBitCast(TVal, DestTy); + return SelectInst::Create(Cond, CastedVal, X, "", nullptr, Sel); + } + + return nullptr; +} + +/// Check if all users of CI are StoreInsts. +static bool hasStoreUsersOnly(CastInst &CI) { + for (User *U : CI.users()) { + if (!isa<StoreInst>(U)) + return false; + } + return true; +} + +/// This function handles following case +/// +/// A -> B cast +/// PHI +/// B -> A cast +/// +/// All the related PHI nodes can be replaced by new PHI nodes with type A. +/// The uses of \p CI can be changed to the new PHI node corresponding to \p PN. +Instruction *InstCombinerImpl::optimizeBitCastFromPhi(CastInst &CI, + PHINode *PN) { + // BitCast used by Store can be handled in InstCombineLoadStoreAlloca.cpp. + if (hasStoreUsersOnly(CI)) + return nullptr; + + Value *Src = CI.getOperand(0); + Type *SrcTy = Src->getType(); // Type B + Type *DestTy = CI.getType(); // Type A + + SmallVector<PHINode *, 4> PhiWorklist; + SmallSetVector<PHINode *, 4> OldPhiNodes; + + // Find all of the A->B casts and PHI nodes. 
+ // We need to inspect all related PHI nodes, but PHIs can be cyclic, so + // OldPhiNodes is used to track all known PHI nodes, before adding a new + // PHI to PhiWorklist, it is checked against and added to OldPhiNodes first. + PhiWorklist.push_back(PN); + OldPhiNodes.insert(PN); + while (!PhiWorklist.empty()) { + auto *OldPN = PhiWorklist.pop_back_val(); + for (Value *IncValue : OldPN->incoming_values()) { + if (isa<Constant>(IncValue)) + continue; + + if (auto *LI = dyn_cast<LoadInst>(IncValue)) { + // If there is a sequence of one or more load instructions, each loaded + // value is used as address of later load instruction, bitcast is + // necessary to change the value type, don't optimize it. For + // simplicity we give up if the load address comes from another load. + Value *Addr = LI->getOperand(0); + if (Addr == &CI || isa<LoadInst>(Addr)) + return nullptr; + // Don't tranform "load <256 x i32>, <256 x i32>*" to + // "load x86_amx, x86_amx*", because x86_amx* is invalid. + // TODO: Remove this check when bitcast between vector and x86_amx + // is replaced with a specific intrinsic. + if (DestTy->isX86_AMXTy()) + return nullptr; + if (LI->hasOneUse() && LI->isSimple()) + continue; + // If a LoadInst has more than one use, changing the type of loaded + // value may create another bitcast. + return nullptr; + } + + if (auto *PNode = dyn_cast<PHINode>(IncValue)) { + if (OldPhiNodes.insert(PNode)) + PhiWorklist.push_back(PNode); + continue; + } + + auto *BCI = dyn_cast<BitCastInst>(IncValue); + // We can't handle other instructions. + if (!BCI) + return nullptr; + + // Verify it's a A->B cast. + Type *TyA = BCI->getOperand(0)->getType(); + Type *TyB = BCI->getType(); + if (TyA != DestTy || TyB != SrcTy) + return nullptr; + } + } + + // Check that each user of each old PHI node is something that we can + // rewrite, so that all of the old PHI nodes can be cleaned up afterwards. 
+ for (auto *OldPN : OldPhiNodes) { + for (User *V : OldPN->users()) { + if (auto *SI = dyn_cast<StoreInst>(V)) { + if (!SI->isSimple() || SI->getOperand(0) != OldPN) + return nullptr; + } else if (auto *BCI = dyn_cast<BitCastInst>(V)) { + // Verify it's a B->A cast. + Type *TyB = BCI->getOperand(0)->getType(); + Type *TyA = BCI->getType(); + if (TyA != DestTy || TyB != SrcTy) + return nullptr; + } else if (auto *PHI = dyn_cast<PHINode>(V)) { + // As long as the user is another old PHI node, then even if we don't + // rewrite it, the PHI web we're considering won't have any users + // outside itself, so it'll be dead. + if (!OldPhiNodes.contains(PHI)) + return nullptr; + } else { + return nullptr; + } + } + } + + // For each old PHI node, create a corresponding new PHI node with a type A. + SmallDenseMap<PHINode *, PHINode *> NewPNodes; + for (auto *OldPN : OldPhiNodes) { + Builder.SetInsertPoint(OldPN); + PHINode *NewPN = Builder.CreatePHI(DestTy, OldPN->getNumOperands()); + NewPNodes[OldPN] = NewPN; + } + + // Fill in the operands of new PHI nodes. + for (auto *OldPN : OldPhiNodes) { + PHINode *NewPN = NewPNodes[OldPN]; + for (unsigned j = 0, e = OldPN->getNumOperands(); j != e; ++j) { + Value *V = OldPN->getOperand(j); + Value *NewV = nullptr; + if (auto *C = dyn_cast<Constant>(V)) { + NewV = ConstantExpr::getBitCast(C, DestTy); + } else if (auto *LI = dyn_cast<LoadInst>(V)) { + // Explicitly perform load combine to make sure no opposing transform + // can remove the bitcast in the meantime and trigger an infinite loop. + Builder.SetInsertPoint(LI); + NewV = combineLoadToNewType(*LI, DestTy); + // Remove the old load and its use in the old phi, which itself becomes + // dead once the whole transform finishes. 
+ replaceInstUsesWith(*LI, PoisonValue::get(LI->getType())); + eraseInstFromFunction(*LI); + } else if (auto *BCI = dyn_cast<BitCastInst>(V)) { + NewV = BCI->getOperand(0); + } else if (auto *PrevPN = dyn_cast<PHINode>(V)) { + NewV = NewPNodes[PrevPN]; + } + assert(NewV); + NewPN->addIncoming(NewV, OldPN->getIncomingBlock(j)); + } + } + + // Traverse all accumulated PHI nodes and process its users, + // which are Stores and BitcCasts. Without this processing + // NewPHI nodes could be replicated and could lead to extra + // moves generated after DeSSA. + // If there is a store with type B, change it to type A. + + + // Replace users of BitCast B->A with NewPHI. These will help + // later to get rid off a closure formed by OldPHI nodes. + Instruction *RetVal = nullptr; + for (auto *OldPN : OldPhiNodes) { + PHINode *NewPN = NewPNodes[OldPN]; + for (User *V : make_early_inc_range(OldPN->users())) { + if (auto *SI = dyn_cast<StoreInst>(V)) { + assert(SI->isSimple() && SI->getOperand(0) == OldPN); + Builder.SetInsertPoint(SI); + auto *NewBC = + cast<BitCastInst>(Builder.CreateBitCast(NewPN, SrcTy)); + SI->setOperand(0, NewBC); + Worklist.push(SI); + assert(hasStoreUsersOnly(*NewBC)); + } + else if (auto *BCI = dyn_cast<BitCastInst>(V)) { + Type *TyB = BCI->getOperand(0)->getType(); + Type *TyA = BCI->getType(); + assert(TyA == DestTy && TyB == SrcTy); + (void) TyA; + (void) TyB; + Instruction *I = replaceInstUsesWith(*BCI, NewPN); + if (BCI == &CI) + RetVal = I; + } else if (auto *PHI = dyn_cast<PHINode>(V)) { + assert(OldPhiNodes.contains(PHI)); + (void) PHI; + } else { + llvm_unreachable("all uses should be handled"); + } + } + } + + return RetVal; +} + +static Instruction *convertBitCastToGEP(BitCastInst &CI, IRBuilderBase &Builder, + const DataLayout &DL) { + Value *Src = CI.getOperand(0); + PointerType *SrcPTy = cast<PointerType>(Src->getType()); + PointerType *DstPTy = cast<PointerType>(CI.getType()); + + // Bitcasts involving opaque pointers cannot be converted 
into a GEP. + if (SrcPTy->isOpaque() || DstPTy->isOpaque()) + return nullptr; + + Type *DstElTy = DstPTy->getNonOpaquePointerElementType(); + Type *SrcElTy = SrcPTy->getNonOpaquePointerElementType(); + + // When the type pointed to is not sized the cast cannot be + // turned into a gep. + if (!SrcElTy->isSized()) + return nullptr; + + // If the source and destination are pointers, and this cast is equivalent + // to a getelementptr X, 0, 0, 0... turn it into the appropriate gep. + // This can enhance SROA and other transforms that want type-safe pointers. + unsigned NumZeros = 0; + while (SrcElTy && SrcElTy != DstElTy) { + SrcElTy = GetElementPtrInst::getTypeAtIndex(SrcElTy, (uint64_t)0); + ++NumZeros; + } + + // If we found a path from the src to dest, create the getelementptr now. + if (SrcElTy == DstElTy) { + SmallVector<Value *, 8> Idxs(NumZeros + 1, Builder.getInt32(0)); + GetElementPtrInst *GEP = GetElementPtrInst::Create( + SrcPTy->getNonOpaquePointerElementType(), Src, Idxs); + + // If the source pointer is dereferenceable, then assume it points to an + // allocated object and apply "inbounds" to the GEP. + bool CanBeNull, CanBeFreed; + if (Src->getPointerDereferenceableBytes(DL, CanBeNull, CanBeFreed)) { + // In a non-default address space (not 0), a null pointer can not be + // assumed inbounds, so ignore that case (dereferenceable_or_null). + // The reason is that 'null' is not treated differently in these address + // spaces, and we consequently ignore the 'gep inbounds' special case + // for 'null' which allows 'inbounds' on 'null' if the indices are + // zeros. + if (SrcPTy->getAddressSpace() == 0 || !CanBeNull) + GEP->setIsInBounds(); + } + return GEP; + } + return nullptr; +} + +Instruction *InstCombinerImpl::visitBitCast(BitCastInst &CI) { + // If the operands are integer typed then apply the integer transforms, + // otherwise just apply the common ones. 
+ Value *Src = CI.getOperand(0); + Type *SrcTy = Src->getType(); + Type *DestTy = CI.getType(); + + // Get rid of casts from one type to the same type. These are useless and can + // be replaced by the operand. + if (DestTy == Src->getType()) + return replaceInstUsesWith(CI, Src); + + if (isa<PointerType>(SrcTy) && isa<PointerType>(DestTy)) { + // If we are casting a alloca to a pointer to a type of the same + // size, rewrite the allocation instruction to allocate the "right" type. + // There is no need to modify malloc calls because it is their bitcast that + // needs to be cleaned up. + if (AllocaInst *AI = dyn_cast<AllocaInst>(Src)) + if (Instruction *V = PromoteCastOfAllocation(CI, *AI)) + return V; + + if (Instruction *I = convertBitCastToGEP(CI, Builder, DL)) + return I; + } + + if (FixedVectorType *DestVTy = dyn_cast<FixedVectorType>(DestTy)) { + // Beware: messing with this target-specific oddity may cause trouble. + if (DestVTy->getNumElements() == 1 && SrcTy->isX86_MMXTy()) { + Value *Elem = Builder.CreateBitCast(Src, DestVTy->getElementType()); + return InsertElementInst::Create(PoisonValue::get(DestTy), Elem, + Constant::getNullValue(Type::getInt32Ty(CI.getContext()))); + } + + if (isa<IntegerType>(SrcTy)) { + // If this is a cast from an integer to vector, check to see if the input + // is a trunc or zext of a bitcast from vector. If so, we can replace all + // the casts with a shuffle and (potentially) a bitcast. + if (isa<TruncInst>(Src) || isa<ZExtInst>(Src)) { + CastInst *SrcCast = cast<CastInst>(Src); + if (BitCastInst *BCIn = dyn_cast<BitCastInst>(SrcCast->getOperand(0))) + if (isa<VectorType>(BCIn->getOperand(0)->getType())) + if (Instruction *I = optimizeVectorResizeWithIntegerBitCasts( + BCIn->getOperand(0), cast<VectorType>(DestTy), *this)) + return I; + } + + // If the input is an 'or' instruction, we may be doing shifts and ors to + // assemble the elements of the vector manually. 
Try to rip the code out + // and replace it with insertelements. + if (Value *V = optimizeIntegerToVectorInsertions(CI, *this)) + return replaceInstUsesWith(CI, V); + } + } + + if (FixedVectorType *SrcVTy = dyn_cast<FixedVectorType>(SrcTy)) { + if (SrcVTy->getNumElements() == 1) { + // If our destination is not a vector, then make this a straight + // scalar-scalar cast. + if (!DestTy->isVectorTy()) { + Value *Elem = + Builder.CreateExtractElement(Src, + Constant::getNullValue(Type::getInt32Ty(CI.getContext()))); + return CastInst::Create(Instruction::BitCast, Elem, DestTy); + } + + // Otherwise, see if our source is an insert. If so, then use the scalar + // component directly: + // bitcast (inselt <1 x elt> V, X, 0) to <n x m> --> bitcast X to <n x m> + if (auto *InsElt = dyn_cast<InsertElementInst>(Src)) + return new BitCastInst(InsElt->getOperand(1), DestTy); + } + + // Convert an artificial vector insert into more analyzable bitwise logic. + unsigned BitWidth = DestTy->getScalarSizeInBits(); + Value *X, *Y; + uint64_t IndexC; + if (match(Src, m_OneUse(m_InsertElt(m_OneUse(m_BitCast(m_Value(X))), + m_Value(Y), m_ConstantInt(IndexC)))) && + DestTy->isIntegerTy() && X->getType() == DestTy && + Y->getType()->isIntegerTy() && isDesirableIntType(BitWidth)) { + // Adjust for big endian - the LSBs are at the high index. + if (DL.isBigEndian()) + IndexC = SrcVTy->getNumElements() - 1 - IndexC; + + // We only handle (endian-normalized) insert to index 0. Any other insert + // would require a left-shift, so that is an extra instruction. 
+ if (IndexC == 0) { + // bitcast (inselt (bitcast X), Y, 0) --> or (and X, MaskC), (zext Y) + unsigned EltWidth = Y->getType()->getScalarSizeInBits(); + APInt MaskC = APInt::getHighBitsSet(BitWidth, BitWidth - EltWidth); + Value *AndX = Builder.CreateAnd(X, MaskC); + Value *ZextY = Builder.CreateZExt(Y, DestTy); + return BinaryOperator::CreateOr(AndX, ZextY); + } + } + } + + if (auto *Shuf = dyn_cast<ShuffleVectorInst>(Src)) { + // Okay, we have (bitcast (shuffle ..)). Check to see if this is + // a bitcast to a vector with the same # elts. + Value *ShufOp0 = Shuf->getOperand(0); + Value *ShufOp1 = Shuf->getOperand(1); + auto ShufElts = cast<VectorType>(Shuf->getType())->getElementCount(); + auto SrcVecElts = cast<VectorType>(ShufOp0->getType())->getElementCount(); + if (Shuf->hasOneUse() && DestTy->isVectorTy() && + cast<VectorType>(DestTy)->getElementCount() == ShufElts && + ShufElts == SrcVecElts) { + BitCastInst *Tmp; + // If either of the operands is a cast from CI.getType(), then + // evaluating the shuffle in the casted destination's type will allow + // us to eliminate at least one cast. + if (((Tmp = dyn_cast<BitCastInst>(ShufOp0)) && + Tmp->getOperand(0)->getType() == DestTy) || + ((Tmp = dyn_cast<BitCastInst>(ShufOp1)) && + Tmp->getOperand(0)->getType() == DestTy)) { + Value *LHS = Builder.CreateBitCast(ShufOp0, DestTy); + Value *RHS = Builder.CreateBitCast(ShufOp1, DestTy); + // Return a new shuffle vector. Use the same element ID's, as we + // know the vector types match #elts. + return new ShuffleVectorInst(LHS, RHS, Shuf->getShuffleMask()); + } + } + + // A bitcasted-to-scalar and byte-reversing shuffle is better recognized as + // a byte-swap: + // bitcast <N x i8> (shuf X, undef, <N, N-1,...0>) --> bswap (bitcast X) + // TODO: We should match the related pattern for bitreverse. 
+ if (DestTy->isIntegerTy() && + DL.isLegalInteger(DestTy->getScalarSizeInBits()) && + SrcTy->getScalarSizeInBits() == 8 && + ShufElts.getKnownMinValue() % 2 == 0 && Shuf->hasOneUse() && + Shuf->isReverse()) { + assert(ShufOp0->getType() == SrcTy && "Unexpected shuffle mask"); + assert(match(ShufOp1, m_Undef()) && "Unexpected shuffle op"); + Function *Bswap = + Intrinsic::getDeclaration(CI.getModule(), Intrinsic::bswap, DestTy); + Value *ScalarX = Builder.CreateBitCast(ShufOp0, DestTy); + return CallInst::Create(Bswap, { ScalarX }); + } + } + + // Handle the A->B->A cast, and there is an intervening PHI node. + if (PHINode *PN = dyn_cast<PHINode>(Src)) + if (Instruction *I = optimizeBitCastFromPhi(CI, PN)) + return I; + + if (Instruction *I = canonicalizeBitCastExtElt(CI, *this)) + return I; + + if (Instruction *I = foldBitCastBitwiseLogic(CI, Builder)) + return I; + + if (Instruction *I = foldBitCastSelect(CI, Builder)) + return I; + + if (SrcTy->isPointerTy()) + return commonPointerCastTransforms(CI); + return commonCastTransforms(CI); +} + +Instruction *InstCombinerImpl::visitAddrSpaceCast(AddrSpaceCastInst &CI) { + // If the destination pointer element type is not the same as the source's + // first do a bitcast to the destination type, and then the addrspacecast. + // This allows the cast to be exposed to other transforms. + Value *Src = CI.getOperand(0); + PointerType *SrcTy = cast<PointerType>(Src->getType()->getScalarType()); + PointerType *DestTy = cast<PointerType>(CI.getType()->getScalarType()); + + if (!SrcTy->hasSameElementTypeAs(DestTy)) { + Type *MidTy = + PointerType::getWithSamePointeeType(DestTy, SrcTy->getAddressSpace()); + // Handle vectors of pointers. 
+ if (VectorType *VT = dyn_cast<VectorType>(CI.getType())) + MidTy = VectorType::get(MidTy, VT->getElementCount()); + + Value *NewBitCast = Builder.CreateBitCast(Src, MidTy); + return new AddrSpaceCastInst(NewBitCast, CI.getType()); + } + + return commonPointerCastTransforms(CI); +} diff --git a/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp new file mode 100644 index 000000000000..d1f89973caa1 --- /dev/null +++ b/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp @@ -0,0 +1,6900 @@ +//===- InstCombineCompares.cpp --------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the visitICmp and visitFCmp functions. +// +//===----------------------------------------------------------------------===// + +#include "InstCombineInternal.h" +#include "llvm/ADT/APSInt.h" +#include "llvm/ADT/SetVector.h" +#include "llvm/ADT/Statistic.h" +#include "llvm/Analysis/CmpInstAnalysis.h" +#include "llvm/Analysis/ConstantFolding.h" +#include "llvm/Analysis/InstructionSimplify.h" +#include "llvm/IR/ConstantRange.h" +#include "llvm/IR/DataLayout.h" +#include "llvm/IR/GetElementPtrTypeIterator.h" +#include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/PatternMatch.h" +#include "llvm/Support/KnownBits.h" +#include "llvm/Transforms/InstCombine/InstCombiner.h" + +using namespace llvm; +using namespace PatternMatch; + +#define DEBUG_TYPE "instcombine" + +// How many times is a select replaced by one of its operands? +STATISTIC(NumSel, "Number of select opts"); + + +/// Compute Result = In1+In2, returning true if the result overflowed for this +/// type. 
+static bool addWithOverflow(APInt &Result, const APInt &In1, + const APInt &In2, bool IsSigned = false) { + bool Overflow; + if (IsSigned) + Result = In1.sadd_ov(In2, Overflow); + else + Result = In1.uadd_ov(In2, Overflow); + + return Overflow; +} + +/// Compute Result = In1-In2, returning true if the result overflowed for this +/// type. +static bool subWithOverflow(APInt &Result, const APInt &In1, + const APInt &In2, bool IsSigned = false) { + bool Overflow; + if (IsSigned) + Result = In1.ssub_ov(In2, Overflow); + else + Result = In1.usub_ov(In2, Overflow); + + return Overflow; +} + +/// Given an icmp instruction, return true if any use of this comparison is a +/// branch on sign bit comparison. +static bool hasBranchUse(ICmpInst &I) { + for (auto *U : I.users()) + if (isa<BranchInst>(U)) + return true; + return false; +} + +/// Returns true if the exploded icmp can be expressed as a signed comparison +/// to zero and updates the predicate accordingly. +/// The signedness of the comparison is preserved. +/// TODO: Refactor with decomposeBitTestICmp()? +static bool isSignTest(ICmpInst::Predicate &Pred, const APInt &C) { + if (!ICmpInst::isSigned(Pred)) + return false; + + if (C.isZero()) + return ICmpInst::isRelational(Pred); + + if (C.isOne()) { + if (Pred == ICmpInst::ICMP_SLT) { + Pred = ICmpInst::ICMP_SLE; + return true; + } + } else if (C.isAllOnes()) { + if (Pred == ICmpInst::ICMP_SGT) { + Pred = ICmpInst::ICMP_SGE; + return true; + } + } + + return false; +} + +/// This is called when we see this pattern: +/// cmp pred (load (gep GV, ...)), cmpcst +/// where GV is a global variable with a constant initializer. Try to simplify +/// this into some simple computation that does not need the load. For example +/// we can optimize "icmp eq (load (gep "foo", 0, i)), 0" into "icmp eq i, 3". +/// +/// If AndCst is non-null, then the loaded value is masked with that constant +/// before doing the comparison. This handles cases like "A[i]&4 == 0". 
+Instruction *InstCombinerImpl::foldCmpLoadFromIndexedGlobal( + LoadInst *LI, GetElementPtrInst *GEP, GlobalVariable *GV, CmpInst &ICI, + ConstantInt *AndCst) { + if (LI->isVolatile() || LI->getType() != GEP->getResultElementType() || + GV->getValueType() != GEP->getSourceElementType() || + !GV->isConstant() || !GV->hasDefinitiveInitializer()) + return nullptr; + + Constant *Init = GV->getInitializer(); + if (!isa<ConstantArray>(Init) && !isa<ConstantDataArray>(Init)) + return nullptr; + + uint64_t ArrayElementCount = Init->getType()->getArrayNumElements(); + // Don't blow up on huge arrays. + if (ArrayElementCount > MaxArraySizeForCombine) + return nullptr; + + // There are many forms of this optimization we can handle, for now, just do + // the simple index into a single-dimensional array. + // + // Require: GEP GV, 0, i {{, constant indices}} + if (GEP->getNumOperands() < 3 || + !isa<ConstantInt>(GEP->getOperand(1)) || + !cast<ConstantInt>(GEP->getOperand(1))->isZero() || + isa<Constant>(GEP->getOperand(2))) + return nullptr; + + // Check that indices after the variable are constants and in-range for the + // type they index. Collect the indices. This is typically for arrays of + // structs. + SmallVector<unsigned, 4> LaterIndices; + + Type *EltTy = Init->getType()->getArrayElementType(); + for (unsigned i = 3, e = GEP->getNumOperands(); i != e; ++i) { + ConstantInt *Idx = dyn_cast<ConstantInt>(GEP->getOperand(i)); + if (!Idx) return nullptr; // Variable index. + + uint64_t IdxVal = Idx->getZExtValue(); + if ((unsigned)IdxVal != IdxVal) return nullptr; // Too large array index. + + if (StructType *STy = dyn_cast<StructType>(EltTy)) + EltTy = STy->getElementType(IdxVal); + else if (ArrayType *ATy = dyn_cast<ArrayType>(EltTy)) { + if (IdxVal >= ATy->getNumElements()) return nullptr; + EltTy = ATy->getElementType(); + } else { + return nullptr; // Unknown type. 
+ } + + LaterIndices.push_back(IdxVal); + } + + enum { Overdefined = -3, Undefined = -2 }; + + // Variables for our state machines. + + // FirstTrueElement/SecondTrueElement - Used to emit a comparison of the form + // "i == 47 | i == 87", where 47 is the first index the condition is true for, + // and 87 is the second (and last) index. FirstTrueElement is -2 when + // undefined, otherwise set to the first true element. SecondTrueElement is + // -2 when undefined, -3 when overdefined and >= 0 when that index is true. + int FirstTrueElement = Undefined, SecondTrueElement = Undefined; + + // FirstFalseElement/SecondFalseElement - Used to emit a comparison of the + // form "i != 47 & i != 87". Same state transitions as for true elements. + int FirstFalseElement = Undefined, SecondFalseElement = Undefined; + + /// TrueRangeEnd/FalseRangeEnd - In conjunction with First*Element, these + /// define a state machine that triggers for ranges of values that the index + /// is true or false for. This triggers on things like "abbbbc"[i] == 'b'. + /// This is -2 when undefined, -3 when overdefined, and otherwise the last + /// index in the range (inclusive). We use -2 for undefined here because we + /// use relative comparisons and don't want 0-1 to match -1. + int TrueRangeEnd = Undefined, FalseRangeEnd = Undefined; + + // MagicBitvector - This is a magic bitvector where we set a bit if the + // comparison is true for element 'i'. If there are 64 elements or less in + // the array, this will fully represent all the comparison results. + uint64_t MagicBitvector = 0; + + // Scan the array and see if one of our patterns matches. + Constant *CompareRHS = cast<Constant>(ICI.getOperand(1)); + for (unsigned i = 0, e = ArrayElementCount; i != e; ++i) { + Constant *Elt = Init->getAggregateElement(i); + if (!Elt) return nullptr; + + // If this is indexing an array of structures, get the structure element. 
+ if (!LaterIndices.empty()) { + Elt = ConstantFoldExtractValueInstruction(Elt, LaterIndices); + if (!Elt) + return nullptr; + } + + // If the element is masked, handle it. + if (AndCst) Elt = ConstantExpr::getAnd(Elt, AndCst); + + // Find out if the comparison would be true or false for the i'th element. + Constant *C = ConstantFoldCompareInstOperands(ICI.getPredicate(), Elt, + CompareRHS, DL, &TLI); + // If the result is undef for this element, ignore it. + if (isa<UndefValue>(C)) { + // Extend range state machines to cover this element in case there is an + // undef in the middle of the range. + if (TrueRangeEnd == (int)i-1) + TrueRangeEnd = i; + if (FalseRangeEnd == (int)i-1) + FalseRangeEnd = i; + continue; + } + + // If we can't compute the result for any of the elements, we have to give + // up evaluating the entire conditional. + if (!isa<ConstantInt>(C)) return nullptr; + + // Otherwise, we know if the comparison is true or false for this element, + // update our state machines. + bool IsTrueForElt = !cast<ConstantInt>(C)->isZero(); + + // State machine for single/double/range index comparison. + if (IsTrueForElt) { + // Update the TrueElement state machine. + if (FirstTrueElement == Undefined) + FirstTrueElement = TrueRangeEnd = i; // First true element. + else { + // Update double-compare state machine. + if (SecondTrueElement == Undefined) + SecondTrueElement = i; + else + SecondTrueElement = Overdefined; + + // Update range state machine. + if (TrueRangeEnd == (int)i-1) + TrueRangeEnd = i; + else + TrueRangeEnd = Overdefined; + } + } else { + // Update the FalseElement state machine. + if (FirstFalseElement == Undefined) + FirstFalseElement = FalseRangeEnd = i; // First false element. + else { + // Update double-compare state machine. + if (SecondFalseElement == Undefined) + SecondFalseElement = i; + else + SecondFalseElement = Overdefined; + + // Update range state machine. 
+ if (FalseRangeEnd == (int)i-1) + FalseRangeEnd = i; + else + FalseRangeEnd = Overdefined; + } + } + + // If this element is in range, update our magic bitvector. + if (i < 64 && IsTrueForElt) + MagicBitvector |= 1ULL << i; + + // If all of our states become overdefined, bail out early. Since the + // predicate is expensive, only check it every 8 elements. This is only + // really useful for really huge arrays. + if ((i & 8) == 0 && i >= 64 && SecondTrueElement == Overdefined && + SecondFalseElement == Overdefined && TrueRangeEnd == Overdefined && + FalseRangeEnd == Overdefined) + return nullptr; + } + + // Now that we've scanned the entire array, emit our new comparison(s). We + // order the state machines in complexity of the generated code. + Value *Idx = GEP->getOperand(2); + + // If the index is larger than the pointer size of the target, truncate the + // index down like the GEP would do implicitly. We don't have to do this for + // an inbounds GEP because the index can't be out of range. + if (!GEP->isInBounds()) { + Type *IntPtrTy = DL.getIntPtrType(GEP->getType()); + unsigned PtrSize = IntPtrTy->getIntegerBitWidth(); + if (Idx->getType()->getPrimitiveSizeInBits().getFixedSize() > PtrSize) + Idx = Builder.CreateTrunc(Idx, IntPtrTy); + } + + // If inbounds keyword is not present, Idx * ElementSize can overflow. + // Let's assume that ElementSize is 2 and the wanted value is at offset 0. + // Then, there are two possible values for Idx to match offset 0: + // 0x00..00, 0x80..00. + // Emitting 'icmp eq Idx, 0' isn't correct in this case because the + // comparison is false if Idx was 0x80..00. + // We need to erase the highest countTrailingZeros(ElementSize) bits of Idx. 
+ unsigned ElementSize = + DL.getTypeAllocSize(Init->getType()->getArrayElementType()); + auto MaskIdx = [&](Value* Idx){ + if (!GEP->isInBounds() && countTrailingZeros(ElementSize) != 0) { + Value *Mask = ConstantInt::get(Idx->getType(), -1); + Mask = Builder.CreateLShr(Mask, countTrailingZeros(ElementSize)); + Idx = Builder.CreateAnd(Idx, Mask); + } + return Idx; + }; + + // If the comparison is only true for one or two elements, emit direct + // comparisons. + if (SecondTrueElement != Overdefined) { + Idx = MaskIdx(Idx); + // None true -> false. + if (FirstTrueElement == Undefined) + return replaceInstUsesWith(ICI, Builder.getFalse()); + + Value *FirstTrueIdx = ConstantInt::get(Idx->getType(), FirstTrueElement); + + // True for one element -> 'i == 47'. + if (SecondTrueElement == Undefined) + return new ICmpInst(ICmpInst::ICMP_EQ, Idx, FirstTrueIdx); + + // True for two elements -> 'i == 47 | i == 72'. + Value *C1 = Builder.CreateICmpEQ(Idx, FirstTrueIdx); + Value *SecondTrueIdx = ConstantInt::get(Idx->getType(), SecondTrueElement); + Value *C2 = Builder.CreateICmpEQ(Idx, SecondTrueIdx); + return BinaryOperator::CreateOr(C1, C2); + } + + // If the comparison is only false for one or two elements, emit direct + // comparisons. + if (SecondFalseElement != Overdefined) { + Idx = MaskIdx(Idx); + // None false -> true. + if (FirstFalseElement == Undefined) + return replaceInstUsesWith(ICI, Builder.getTrue()); + + Value *FirstFalseIdx = ConstantInt::get(Idx->getType(), FirstFalseElement); + + // False for one element -> 'i != 47'. + if (SecondFalseElement == Undefined) + return new ICmpInst(ICmpInst::ICMP_NE, Idx, FirstFalseIdx); + + // False for two elements -> 'i != 47 & i != 72'. 
+ Value *C1 = Builder.CreateICmpNE(Idx, FirstFalseIdx); + Value *SecondFalseIdx = ConstantInt::get(Idx->getType(),SecondFalseElement); + Value *C2 = Builder.CreateICmpNE(Idx, SecondFalseIdx); + return BinaryOperator::CreateAnd(C1, C2); + } + + // If the comparison can be replaced with a range comparison for the elements + // where it is true, emit the range check. + if (TrueRangeEnd != Overdefined) { + assert(TrueRangeEnd != FirstTrueElement && "Should emit single compare"); + Idx = MaskIdx(Idx); + + // Generate (i-FirstTrue) <u (TrueRangeEnd-FirstTrue+1). + if (FirstTrueElement) { + Value *Offs = ConstantInt::get(Idx->getType(), -FirstTrueElement); + Idx = Builder.CreateAdd(Idx, Offs); + } + + Value *End = ConstantInt::get(Idx->getType(), + TrueRangeEnd-FirstTrueElement+1); + return new ICmpInst(ICmpInst::ICMP_ULT, Idx, End); + } + + // False range check. + if (FalseRangeEnd != Overdefined) { + assert(FalseRangeEnd != FirstFalseElement && "Should emit single compare"); + Idx = MaskIdx(Idx); + // Generate (i-FirstFalse) >u (FalseRangeEnd-FirstFalse). 
+ if (FirstFalseElement) { + Value *Offs = ConstantInt::get(Idx->getType(), -FirstFalseElement); + Idx = Builder.CreateAdd(Idx, Offs); + } + + Value *End = ConstantInt::get(Idx->getType(), + FalseRangeEnd-FirstFalseElement); + return new ICmpInst(ICmpInst::ICMP_UGT, Idx, End); + } + + // If a magic bitvector captures the entire comparison state + // of this load, replace it with computation that does: + // ((magic_cst >> i) & 1) != 0 + { + Type *Ty = nullptr; + + // Look for an appropriate type: + // - The type of Idx if the magic fits + // - The smallest fitting legal type + if (ArrayElementCount <= Idx->getType()->getIntegerBitWidth()) + Ty = Idx->getType(); + else + Ty = DL.getSmallestLegalIntType(Init->getContext(), ArrayElementCount); + + if (Ty) { + Idx = MaskIdx(Idx); + Value *V = Builder.CreateIntCast(Idx, Ty, false); + V = Builder.CreateLShr(ConstantInt::get(Ty, MagicBitvector), V); + V = Builder.CreateAnd(ConstantInt::get(Ty, 1), V); + return new ICmpInst(ICmpInst::ICMP_NE, V, ConstantInt::get(Ty, 0)); + } + } + + return nullptr; +} + +/// Return a value that can be used to compare the *offset* implied by a GEP to +/// zero. For example, if we have &A[i], we want to return 'i' for +/// "icmp ne i, 0". Note that, in general, indices can be complex, and scales +/// are involved. The above expression would also be legal to codegen as +/// "icmp ne (i*4), 0" (assuming A is a pointer to i32). +/// This latter form is less amenable to optimization though, and we are allowed +/// to generate the first by knowing that pointer arithmetic doesn't overflow. +/// +/// If we can't emit an optimized form for this expression, this returns null. +/// +static Value *evaluateGEPOffsetExpression(User *GEP, InstCombinerImpl &IC, + const DataLayout &DL) { + gep_type_iterator GTI = gep_type_begin(GEP); + + // Check to see if this gep only has a single variable index. 
If so, and if + // any constant indices are a multiple of its scale, then we can compute this + // in terms of the scale of the variable index. For example, if the GEP + // implies an offset of "12 + i*4", then we can codegen this as "3 + i", + // because the expression will cross zero at the same point. + unsigned i, e = GEP->getNumOperands(); + int64_t Offset = 0; + for (i = 1; i != e; ++i, ++GTI) { + if (ConstantInt *CI = dyn_cast<ConstantInt>(GEP->getOperand(i))) { + // Compute the aggregate offset of constant indices. + if (CI->isZero()) continue; + + // Handle a struct index, which adds its field offset to the pointer. + if (StructType *STy = GTI.getStructTypeOrNull()) { + Offset += DL.getStructLayout(STy)->getElementOffset(CI->getZExtValue()); + } else { + uint64_t Size = DL.getTypeAllocSize(GTI.getIndexedType()); + Offset += Size*CI->getSExtValue(); + } + } else { + // Found our variable index. + break; + } + } + + // If there are no variable indices, we must have a constant offset, just + // evaluate it the general way. + if (i == e) return nullptr; + + Value *VariableIdx = GEP->getOperand(i); + // Determine the scale factor of the variable element. For example, this is + // 4 if the variable index is into an array of i32. + uint64_t VariableScale = DL.getTypeAllocSize(GTI.getIndexedType()); + + // Verify that there are no other variable indices. If so, emit the hard way. + for (++i, ++GTI; i != e; ++i, ++GTI) { + ConstantInt *CI = dyn_cast<ConstantInt>(GEP->getOperand(i)); + if (!CI) return nullptr; + + // Compute the aggregate offset of constant indices. + if (CI->isZero()) continue; + + // Handle a struct index, which adds its field offset to the pointer. 
+ if (StructType *STy = GTI.getStructTypeOrNull()) { + Offset += DL.getStructLayout(STy)->getElementOffset(CI->getZExtValue()); + } else { + uint64_t Size = DL.getTypeAllocSize(GTI.getIndexedType()); + Offset += Size*CI->getSExtValue(); + } + } + + // Okay, we know we have a single variable index, which must be a + // pointer/array/vector index. If there is no offset, life is simple, return + // the index. + Type *IntPtrTy = DL.getIntPtrType(GEP->getOperand(0)->getType()); + unsigned IntPtrWidth = IntPtrTy->getIntegerBitWidth(); + if (Offset == 0) { + // Cast to intptrty in case a truncation occurs. If an extension is needed, + // we don't need to bother extending: the extension won't affect where the + // computation crosses zero. + if (VariableIdx->getType()->getPrimitiveSizeInBits().getFixedSize() > + IntPtrWidth) { + VariableIdx = IC.Builder.CreateTrunc(VariableIdx, IntPtrTy); + } + return VariableIdx; + } + + // Otherwise, there is an index. The computation we will do will be modulo + // the pointer size. + Offset = SignExtend64(Offset, IntPtrWidth); + VariableScale = SignExtend64(VariableScale, IntPtrWidth); + + // To do this transformation, any constant index must be a multiple of the + // variable scale factor. For example, we can evaluate "12 + 4*i" as "3 + i", + // but we can't evaluate "10 + 3*i" in terms of i. Check that the offset is a + // multiple of the variable scale. + int64_t NewOffs = Offset / (int64_t)VariableScale; + if (Offset != NewOffs*(int64_t)VariableScale) + return nullptr; + + // Okay, we can do this evaluation. Start by converting the index to intptr. + if (VariableIdx->getType() != IntPtrTy) + VariableIdx = IC.Builder.CreateIntCast(VariableIdx, IntPtrTy, + true /*Signed*/); + Constant *OffsetVal = ConstantInt::get(IntPtrTy, NewOffs); + return IC.Builder.CreateAdd(VariableIdx, OffsetVal, "offset"); +} + +/// Returns true if we can rewrite Start as a GEP with pointer Base +/// and some integer offset. 
The nodes that need to be re-written +/// for this transformation will be added to Explored. +static bool canRewriteGEPAsOffset(Type *ElemTy, Value *Start, Value *Base, + const DataLayout &DL, + SetVector<Value *> &Explored) { + SmallVector<Value *, 16> WorkList(1, Start); + Explored.insert(Base); + + // The following traversal gives us an order which can be used + // when doing the final transformation. Since in the final + // transformation we create the PHI replacement instructions first, + // we don't have to get them in any particular order. + // + // However, for other instructions we will have to traverse the + // operands of an instruction first, which means that we have to + // do a post-order traversal. + while (!WorkList.empty()) { + SetVector<PHINode *> PHIs; + + while (!WorkList.empty()) { + if (Explored.size() >= 100) + return false; + + Value *V = WorkList.back(); + + if (Explored.contains(V)) { + WorkList.pop_back(); + continue; + } + + if (!isa<IntToPtrInst>(V) && !isa<PtrToIntInst>(V) && + !isa<GetElementPtrInst>(V) && !isa<PHINode>(V)) + // We've found some value that we can't explore which is different from + // the base. Therefore we can't do this transformation. + return false; + + if (isa<IntToPtrInst>(V) || isa<PtrToIntInst>(V)) { + auto *CI = cast<CastInst>(V); + if (!CI->isNoopCast(DL)) + return false; + + if (!Explored.contains(CI->getOperand(0))) + WorkList.push_back(CI->getOperand(0)); + } + + if (auto *GEP = dyn_cast<GEPOperator>(V)) { + // We're limiting the GEP to having one index. This will preserve + // the original pointer type. We could handle more cases in the + // future. + if (GEP->getNumIndices() != 1 || !GEP->isInBounds() || + GEP->getSourceElementType() != ElemTy) + return false; + + if (!Explored.contains(GEP->getOperand(0))) + WorkList.push_back(GEP->getOperand(0)); + } + + if (WorkList.back() == V) { + WorkList.pop_back(); + // We've finished visiting this node, mark it as such. 
+ Explored.insert(V); + } + + if (auto *PN = dyn_cast<PHINode>(V)) { + // We cannot transform PHIs on unsplittable basic blocks. + if (isa<CatchSwitchInst>(PN->getParent()->getTerminator())) + return false; + Explored.insert(PN); + PHIs.insert(PN); + } + } + + // Explore the PHI nodes further. + for (auto *PN : PHIs) + for (Value *Op : PN->incoming_values()) + if (!Explored.contains(Op)) + WorkList.push_back(Op); + } + + // Make sure that we can do this. Since we can't insert GEPs in a basic + // block before a PHI node, we can't easily do this transformation if + // we have PHI node users of transformed instructions. + for (Value *Val : Explored) { + for (Value *Use : Val->uses()) { + + auto *PHI = dyn_cast<PHINode>(Use); + auto *Inst = dyn_cast<Instruction>(Val); + + if (Inst == Base || Inst == PHI || !Inst || !PHI || + !Explored.contains(PHI)) + continue; + + if (PHI->getParent() == Inst->getParent()) + return false; + } + } + return true; +} + +// Sets the appropriate insert point on Builder where we can add +// a replacement Instruction for V (if that is possible). +static void setInsertionPoint(IRBuilder<> &Builder, Value *V, + bool Before = true) { + if (auto *PHI = dyn_cast<PHINode>(V)) { + Builder.SetInsertPoint(&*PHI->getParent()->getFirstInsertionPt()); + return; + } + if (auto *I = dyn_cast<Instruction>(V)) { + if (!Before) + I = &*std::next(I->getIterator()); + Builder.SetInsertPoint(I); + return; + } + if (auto *A = dyn_cast<Argument>(V)) { + // Set the insertion point in the entry block. + BasicBlock &Entry = A->getParent()->getEntryBlock(); + Builder.SetInsertPoint(&*Entry.getFirstInsertionPt()); + return; + } + // Otherwise, this is a constant and we don't need to set a new + // insertion point. + assert(isa<Constant>(V) && "Setting insertion point for unknown value!"); +} + +/// Returns a re-written value of Start as an indexed GEP using Base as a +/// pointer. 
+static Value *rewriteGEPAsOffset(Type *ElemTy, Value *Start, Value *Base, + const DataLayout &DL, + SetVector<Value *> &Explored) { + // Perform all the substitutions. This is a bit tricky because we can + // have cycles in our use-def chains. + // 1. Create the PHI nodes without any incoming values. + // 2. Create all the other values. + // 3. Add the edges for the PHI nodes. + // 4. Emit GEPs to get the original pointers. + // 5. Remove the original instructions. + Type *IndexType = IntegerType::get( + Base->getContext(), DL.getIndexTypeSizeInBits(Start->getType())); + + DenseMap<Value *, Value *> NewInsts; + NewInsts[Base] = ConstantInt::getNullValue(IndexType); + + // Create the new PHI nodes, without adding any incoming values. + for (Value *Val : Explored) { + if (Val == Base) + continue; + // Create empty phi nodes. This avoids cyclic dependencies when creating + // the remaining instructions. + if (auto *PHI = dyn_cast<PHINode>(Val)) + NewInsts[PHI] = PHINode::Create(IndexType, PHI->getNumIncomingValues(), + PHI->getName() + ".idx", PHI); + } + IRBuilder<> Builder(Base->getContext()); + + // Create all the other instructions. + for (Value *Val : Explored) { + + if (NewInsts.find(Val) != NewInsts.end()) + continue; + + if (auto *CI = dyn_cast<CastInst>(Val)) { + // Don't get rid of the intermediate variable here; the store can grow + // the map which will invalidate the reference to the input value. + Value *V = NewInsts[CI->getOperand(0)]; + NewInsts[CI] = V; + continue; + } + if (auto *GEP = dyn_cast<GEPOperator>(Val)) { + Value *Index = NewInsts[GEP->getOperand(1)] ? NewInsts[GEP->getOperand(1)] + : GEP->getOperand(1); + setInsertionPoint(Builder, GEP); + // Indices might need to be sign extended. GEPs will magically do + // this, but we need to do it ourselves here. 
+ if (Index->getType()->getScalarSizeInBits() != + NewInsts[GEP->getOperand(0)]->getType()->getScalarSizeInBits()) { + Index = Builder.CreateSExtOrTrunc( + Index, NewInsts[GEP->getOperand(0)]->getType(), + GEP->getOperand(0)->getName() + ".sext"); + } + + auto *Op = NewInsts[GEP->getOperand(0)]; + if (isa<ConstantInt>(Op) && cast<ConstantInt>(Op)->isZero()) + NewInsts[GEP] = Index; + else + NewInsts[GEP] = Builder.CreateNSWAdd( + Op, Index, GEP->getOperand(0)->getName() + ".add"); + continue; + } + if (isa<PHINode>(Val)) + continue; + + llvm_unreachable("Unexpected instruction type"); + } + + // Add the incoming values to the PHI nodes. + for (Value *Val : Explored) { + if (Val == Base) + continue; + // All the instructions have been created, we can now add edges to the + // phi nodes. + if (auto *PHI = dyn_cast<PHINode>(Val)) { + PHINode *NewPhi = static_cast<PHINode *>(NewInsts[PHI]); + for (unsigned I = 0, E = PHI->getNumIncomingValues(); I < E; ++I) { + Value *NewIncoming = PHI->getIncomingValue(I); + + if (NewInsts.find(NewIncoming) != NewInsts.end()) + NewIncoming = NewInsts[NewIncoming]; + + NewPhi->addIncoming(NewIncoming, PHI->getIncomingBlock(I)); + } + } + } + + PointerType *PtrTy = + ElemTy->getPointerTo(Start->getType()->getPointerAddressSpace()); + for (Value *Val : Explored) { + if (Val == Base) + continue; + + // Depending on the type, for external users we have to emit + // a GEP or a GEP + ptrtoint. + setInsertionPoint(Builder, Val, false); + + // Cast base to the expected type. 
+ Value *NewVal = Builder.CreateBitOrPointerCast( + Base, PtrTy, Start->getName() + "to.ptr"); + NewVal = Builder.CreateInBoundsGEP( + ElemTy, NewVal, makeArrayRef(NewInsts[Val]), Val->getName() + ".ptr"); + NewVal = Builder.CreateBitOrPointerCast( + NewVal, Val->getType(), Val->getName() + ".conv"); + Val->replaceAllUsesWith(NewVal); + } + + return NewInsts[Start]; +} + +/// Looks through GEPs, IntToPtrInsts and PtrToIntInsts in order to express +/// the input Value as a constant indexed GEP. Returns a pair containing +/// the GEPs Pointer and Index. +static std::pair<Value *, Value *> +getAsConstantIndexedAddress(Type *ElemTy, Value *V, const DataLayout &DL) { + Type *IndexType = IntegerType::get(V->getContext(), + DL.getIndexTypeSizeInBits(V->getType())); + + Constant *Index = ConstantInt::getNullValue(IndexType); + while (true) { + if (GEPOperator *GEP = dyn_cast<GEPOperator>(V)) { + // We accept only inbouds GEPs here to exclude the possibility of + // overflow. + if (!GEP->isInBounds()) + break; + if (GEP->hasAllConstantIndices() && GEP->getNumIndices() == 1 && + GEP->getSourceElementType() == ElemTy) { + V = GEP->getOperand(0); + Constant *GEPIndex = static_cast<Constant *>(GEP->getOperand(1)); + Index = ConstantExpr::getAdd( + Index, ConstantExpr::getSExtOrTrunc(GEPIndex, IndexType)); + continue; + } + break; + } + if (auto *CI = dyn_cast<IntToPtrInst>(V)) { + if (!CI->isNoopCast(DL)) + break; + V = CI->getOperand(0); + continue; + } + if (auto *CI = dyn_cast<PtrToIntInst>(V)) { + if (!CI->isNoopCast(DL)) + break; + V = CI->getOperand(0); + continue; + } + break; + } + return {V, Index}; +} + +/// Converts (CMP GEPLHS, RHS) if this change would make RHS a constant. +/// We can look through PHIs, GEPs and casts in order to determine a common base +/// between GEPLHS and RHS. +static Instruction *transformToIndexedCompare(GEPOperator *GEPLHS, Value *RHS, + ICmpInst::Predicate Cond, + const DataLayout &DL) { + // FIXME: Support vector of pointers. 
+ if (GEPLHS->getType()->isVectorTy()) + return nullptr; + + if (!GEPLHS->hasAllConstantIndices()) + return nullptr; + + Type *ElemTy = GEPLHS->getSourceElementType(); + Value *PtrBase, *Index; + std::tie(PtrBase, Index) = getAsConstantIndexedAddress(ElemTy, GEPLHS, DL); + + // The set of nodes that will take part in this transformation. + SetVector<Value *> Nodes; + + if (!canRewriteGEPAsOffset(ElemTy, RHS, PtrBase, DL, Nodes)) + return nullptr; + + // We know we can re-write this as + // ((gep Ptr, OFFSET1) cmp (gep Ptr, OFFSET2) + // Since we've only looked through inbouds GEPs we know that we + // can't have overflow on either side. We can therefore re-write + // this as: + // OFFSET1 cmp OFFSET2 + Value *NewRHS = rewriteGEPAsOffset(ElemTy, RHS, PtrBase, DL, Nodes); + + // RewriteGEPAsOffset has replaced RHS and all of its uses with a re-written + // GEP having PtrBase as the pointer base, and has returned in NewRHS the + // offset. Since Index is the offset of LHS to the base pointer, we will now + // compare the offsets instead of comparing the pointers. + return new ICmpInst(ICmpInst::getSignedPredicate(Cond), Index, NewRHS); +} + +/// Fold comparisons between a GEP instruction and something else. At this point +/// we know that the GEP is on the LHS of the comparison. +Instruction *InstCombinerImpl::foldGEPICmp(GEPOperator *GEPLHS, Value *RHS, + ICmpInst::Predicate Cond, + Instruction &I) { + // Don't transform signed compares of GEPs into index compares. Even if the + // GEP is inbounds, the final add of the base pointer can have signed overflow + // and would change the result of the icmp. + // e.g. "&foo[0] <s &foo[1]" can't be folded to "true" because "foo" could be + // the maximum signed value for the pointer type. + if (ICmpInst::isSigned(Cond)) + return nullptr; + + // Look through bitcasts and addrspacecasts. We do not however want to remove + // 0 GEPs. 
+ if (!isa<GetElementPtrInst>(RHS)) + RHS = RHS->stripPointerCasts(); + + Value *PtrBase = GEPLHS->getOperand(0); + // FIXME: Support vector pointer GEPs. + if (PtrBase == RHS && GEPLHS->isInBounds() && + !GEPLHS->getType()->isVectorTy()) { + // ((gep Ptr, OFFSET) cmp Ptr) ---> (OFFSET cmp 0). + // This transformation (ignoring the base and scales) is valid because we + // know pointers can't overflow since the gep is inbounds. See if we can + // output an optimized form. + Value *Offset = evaluateGEPOffsetExpression(GEPLHS, *this, DL); + + // If not, synthesize the offset the hard way. + if (!Offset) + Offset = EmitGEPOffset(GEPLHS); + return new ICmpInst(ICmpInst::getSignedPredicate(Cond), Offset, + Constant::getNullValue(Offset->getType())); + } + + if (GEPLHS->isInBounds() && ICmpInst::isEquality(Cond) && + isa<Constant>(RHS) && cast<Constant>(RHS)->isNullValue() && + !NullPointerIsDefined(I.getFunction(), + RHS->getType()->getPointerAddressSpace())) { + // For most address spaces, an allocation can't be placed at null, but null + // itself is treated as a 0 size allocation in the in bounds rules. Thus, + // the only valid inbounds address derived from null, is null itself. + // Thus, we have four cases to consider: + // 1) Base == nullptr, Offset == 0 -> inbounds, null + // 2) Base == nullptr, Offset != 0 -> poison as the result is out of bounds + // 3) Base != nullptr, Offset == (-base) -> poison (crossing allocations) + // 4) Base != nullptr, Offset != (-base) -> nonnull (and possibly poison) + // + // (Note if we're indexing a type of size 0, that simply collapses into one + // of the buckets above.) + // + // In general, we're allowed to make values less poison (i.e. remove + // sources of full UB), so in this case, we just select between the two + // non-poison cases (1 and 4 above). + // + // For vectors, we apply the same reasoning on a per-lane basis. 
+ auto *Base = GEPLHS->getPointerOperand(); + if (GEPLHS->getType()->isVectorTy() && Base->getType()->isPointerTy()) { + auto EC = cast<VectorType>(GEPLHS->getType())->getElementCount(); + Base = Builder.CreateVectorSplat(EC, Base); + } + return new ICmpInst(Cond, Base, + ConstantExpr::getPointerBitCastOrAddrSpaceCast( + cast<Constant>(RHS), Base->getType())); + } else if (GEPOperator *GEPRHS = dyn_cast<GEPOperator>(RHS)) { + // If the base pointers are different, but the indices are the same, just + // compare the base pointer. + if (PtrBase != GEPRHS->getOperand(0)) { + bool IndicesTheSame = + GEPLHS->getNumOperands() == GEPRHS->getNumOperands() && + GEPLHS->getPointerOperand()->getType() == + GEPRHS->getPointerOperand()->getType() && + GEPLHS->getSourceElementType() == GEPRHS->getSourceElementType(); + if (IndicesTheSame) + for (unsigned i = 1, e = GEPLHS->getNumOperands(); i != e; ++i) + if (GEPLHS->getOperand(i) != GEPRHS->getOperand(i)) { + IndicesTheSame = false; + break; + } + + // If all indices are the same, just compare the base pointers. + Type *BaseType = GEPLHS->getOperand(0)->getType(); + if (IndicesTheSame && CmpInst::makeCmpResultType(BaseType) == I.getType()) + return new ICmpInst(Cond, GEPLHS->getOperand(0), GEPRHS->getOperand(0)); + + // If we're comparing GEPs with two base pointers that only differ in type + // and both GEPs have only constant indices or just one use, then fold + // the compare with the adjusted indices. + // FIXME: Support vector of pointers. 
+ if (GEPLHS->isInBounds() && GEPRHS->isInBounds() && + (GEPLHS->hasAllConstantIndices() || GEPLHS->hasOneUse()) && + (GEPRHS->hasAllConstantIndices() || GEPRHS->hasOneUse()) && + PtrBase->stripPointerCasts() == + GEPRHS->getOperand(0)->stripPointerCasts() && + !GEPLHS->getType()->isVectorTy()) { + Value *LOffset = EmitGEPOffset(GEPLHS); + Value *ROffset = EmitGEPOffset(GEPRHS); + + // If we looked through an addrspacecast between different sized address + // spaces, the LHS and RHS pointers are different sized + // integers. Truncate to the smaller one. + Type *LHSIndexTy = LOffset->getType(); + Type *RHSIndexTy = ROffset->getType(); + if (LHSIndexTy != RHSIndexTy) { + if (LHSIndexTy->getPrimitiveSizeInBits().getFixedSize() < + RHSIndexTy->getPrimitiveSizeInBits().getFixedSize()) { + ROffset = Builder.CreateTrunc(ROffset, LHSIndexTy); + } else + LOffset = Builder.CreateTrunc(LOffset, RHSIndexTy); + } + + Value *Cmp = Builder.CreateICmp(ICmpInst::getSignedPredicate(Cond), + LOffset, ROffset); + return replaceInstUsesWith(I, Cmp); + } + + // Otherwise, the base pointers are different and the indices are + // different. Try convert this to an indexed compare by looking through + // PHIs/casts. + return transformToIndexedCompare(GEPLHS, RHS, Cond, DL); + } + + // If one of the GEPs has all zero indices, recurse. + // FIXME: Handle vector of pointers. + if (!GEPLHS->getType()->isVectorTy() && GEPLHS->hasAllZeroIndices()) + return foldGEPICmp(GEPRHS, GEPLHS->getOperand(0), + ICmpInst::getSwappedPredicate(Cond), I); + + // If the other GEP has all zero indices, recurse. + // FIXME: Handle vector of pointers. 
+ if (!GEPRHS->getType()->isVectorTy() && GEPRHS->hasAllZeroIndices()) + return foldGEPICmp(GEPLHS, GEPRHS->getOperand(0), Cond, I); + + bool GEPsInBounds = GEPLHS->isInBounds() && GEPRHS->isInBounds(); + if (GEPLHS->getNumOperands() == GEPRHS->getNumOperands() && + GEPLHS->getSourceElementType() == GEPRHS->getSourceElementType()) { + // If the GEPs only differ by one index, compare it. + unsigned NumDifferences = 0; // Keep track of # differences. + unsigned DiffOperand = 0; // The operand that differs. + for (unsigned i = 1, e = GEPRHS->getNumOperands(); i != e; ++i) + if (GEPLHS->getOperand(i) != GEPRHS->getOperand(i)) { + Type *LHSType = GEPLHS->getOperand(i)->getType(); + Type *RHSType = GEPRHS->getOperand(i)->getType(); + // FIXME: Better support for vector of pointers. + if (LHSType->getPrimitiveSizeInBits() != + RHSType->getPrimitiveSizeInBits() || + (GEPLHS->getType()->isVectorTy() && + (!LHSType->isVectorTy() || !RHSType->isVectorTy()))) { + // Irreconcilable differences. + NumDifferences = 2; + break; + } + + if (NumDifferences++) break; + DiffOperand = i; + } + + if (NumDifferences == 0) // SAME GEP? + return replaceInstUsesWith(I, // No comparison is needed here. + ConstantInt::get(I.getType(), ICmpInst::isTrueWhenEqual(Cond))); + + else if (NumDifferences == 1 && GEPsInBounds) { + Value *LHSV = GEPLHS->getOperand(DiffOperand); + Value *RHSV = GEPRHS->getOperand(DiffOperand); + // Make sure we do a signed comparison here. + return new ICmpInst(ICmpInst::getSignedPredicate(Cond), LHSV, RHSV); + } + } + + // Only lower this if the icmp is the only user of the GEP or if we expect + // the result to fold to a constant! 
+ if (GEPsInBounds && (isa<ConstantExpr>(GEPLHS) || GEPLHS->hasOneUse()) && + (isa<ConstantExpr>(GEPRHS) || GEPRHS->hasOneUse())) { + // ((gep Ptr, OFFSET1) cmp (gep Ptr, OFFSET2) ---> (OFFSET1 cmp OFFSET2) + Value *L = EmitGEPOffset(GEPLHS); + Value *R = EmitGEPOffset(GEPRHS); + return new ICmpInst(ICmpInst::getSignedPredicate(Cond), L, R); + } + } + + // Try convert this to an indexed compare by looking through PHIs/casts as a + // last resort. + return transformToIndexedCompare(GEPLHS, RHS, Cond, DL); +} + +Instruction *InstCombinerImpl::foldAllocaCmp(ICmpInst &ICI, + const AllocaInst *Alloca) { + assert(ICI.isEquality() && "Cannot fold non-equality comparison."); + + // It would be tempting to fold away comparisons between allocas and any + // pointer not based on that alloca (e.g. an argument). However, even + // though such pointers cannot alias, they can still compare equal. + // + // But LLVM doesn't specify where allocas get their memory, so if the alloca + // doesn't escape we can argue that it's impossible to guess its value, and we + // can therefore act as if any such guesses are wrong. + // + // The code below checks that the alloca doesn't escape, and that it's only + // used in a comparison once (the current instruction). The + // single-comparison-use condition ensures that we're trivially folding all + // comparisons against the alloca consistently, and avoids the risk of + // erroneously folding a comparison of the pointer with itself. + + unsigned MaxIter = 32; // Break cycles and bound to constant-time. 
+ + SmallVector<const Use *, 32> Worklist; + for (const Use &U : Alloca->uses()) { + if (Worklist.size() >= MaxIter) + return nullptr; + Worklist.push_back(&U); + } + + unsigned NumCmps = 0; + while (!Worklist.empty()) { + assert(Worklist.size() <= MaxIter); + const Use *U = Worklist.pop_back_val(); + const Value *V = U->getUser(); + --MaxIter; + + if (isa<BitCastInst>(V) || isa<GetElementPtrInst>(V) || isa<PHINode>(V) || + isa<SelectInst>(V)) { + // Track the uses. + } else if (isa<LoadInst>(V)) { + // Loading from the pointer doesn't escape it. + continue; + } else if (const auto *SI = dyn_cast<StoreInst>(V)) { + // Storing *to* the pointer is fine, but storing the pointer escapes it. + if (SI->getValueOperand() == U->get()) + return nullptr; + continue; + } else if (isa<ICmpInst>(V)) { + if (NumCmps++) + return nullptr; // Found more than one cmp. + continue; + } else if (const auto *Intrin = dyn_cast<IntrinsicInst>(V)) { + switch (Intrin->getIntrinsicID()) { + // These intrinsics don't escape or compare the pointer. Memset is safe + // because we don't allow ptrtoint. Memcpy and memmove are safe because + // we don't allow stores, so src cannot point to V. + case Intrinsic::lifetime_start: case Intrinsic::lifetime_end: + case Intrinsic::memcpy: case Intrinsic::memmove: case Intrinsic::memset: + continue; + default: + return nullptr; + } + } else { + return nullptr; + } + for (const Use &U : V->uses()) { + if (Worklist.size() >= MaxIter) + return nullptr; + Worklist.push_back(&U); + } + } + + auto *Res = ConstantInt::get(ICI.getType(), + !CmpInst::isTrueWhenEqual(ICI.getPredicate())); + return replaceInstUsesWith(ICI, Res); +} + +/// Fold "icmp pred (X+C), X". +Instruction *InstCombinerImpl::foldICmpAddOpConst(Value *X, const APInt &C, + ICmpInst::Predicate Pred) { + // From this point on, we know that (X+C <= X) --> (X+C < X) because C != 0, + // so the values can never be equal. Similarly for all other "or equals" + // operators. 
+ assert(!!C && "C should not be zero!"); + + // (X+1) <u X --> X >u (MAXUINT-1) --> X == 255 + // (X+2) <u X --> X >u (MAXUINT-2) --> X > 253 + // (X+MAXUINT) <u X --> X >u (MAXUINT-MAXUINT) --> X != 0 + if (Pred == ICmpInst::ICMP_ULT || Pred == ICmpInst::ICMP_ULE) { + Constant *R = ConstantInt::get(X->getType(), + APInt::getMaxValue(C.getBitWidth()) - C); + return new ICmpInst(ICmpInst::ICMP_UGT, X, R); + } + + // (X+1) >u X --> X <u (0-1) --> X != 255 + // (X+2) >u X --> X <u (0-2) --> X <u 254 + // (X+MAXUINT) >u X --> X <u (0-MAXUINT) --> X <u 1 --> X == 0 + if (Pred == ICmpInst::ICMP_UGT || Pred == ICmpInst::ICMP_UGE) + return new ICmpInst(ICmpInst::ICMP_ULT, X, + ConstantInt::get(X->getType(), -C)); + + APInt SMax = APInt::getSignedMaxValue(C.getBitWidth()); + + // (X+ 1) <s X --> X >s (MAXSINT-1) --> X == 127 + // (X+ 2) <s X --> X >s (MAXSINT-2) --> X >s 125 + // (X+MAXSINT) <s X --> X >s (MAXSINT-MAXSINT) --> X >s 0 + // (X+MINSINT) <s X --> X >s (MAXSINT-MINSINT) --> X >s -1 + // (X+ -2) <s X --> X >s (MAXSINT- -2) --> X >s 126 + // (X+ -1) <s X --> X >s (MAXSINT- -1) --> X != 127 + if (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SLE) + return new ICmpInst(ICmpInst::ICMP_SGT, X, + ConstantInt::get(X->getType(), SMax - C)); + + // (X+ 1) >s X --> X <s (MAXSINT-(1-1)) --> X != 127 + // (X+ 2) >s X --> X <s (MAXSINT-(2-1)) --> X <s 126 + // (X+MAXSINT) >s X --> X <s (MAXSINT-(MAXSINT-1)) --> X <s 1 + // (X+MINSINT) >s X --> X <s (MAXSINT-(MINSINT-1)) --> X <s -2 + // (X+ -2) >s X --> X <s (MAXSINT-(-2-1)) --> X <s -126 + // (X+ -1) >s X --> X <s (MAXSINT-(-1-1)) --> X == -128 + + assert(Pred == ICmpInst::ICMP_SGT || Pred == ICmpInst::ICMP_SGE); + return new ICmpInst(ICmpInst::ICMP_SLT, X, + ConstantInt::get(X->getType(), SMax - (C - 1))); +} + +/// Handle "(icmp eq/ne (ashr/lshr AP2, A), AP1)" -> +/// (icmp eq/ne A, Log2(AP2/AP1)) -> +/// (icmp eq/ne A, Log2(AP2) - Log2(AP1)). 
+Instruction *InstCombinerImpl::foldICmpShrConstConst(ICmpInst &I, Value *A, + const APInt &AP1, + const APInt &AP2) { + assert(I.isEquality() && "Cannot fold icmp gt/lt"); + + auto getICmp = [&I](CmpInst::Predicate Pred, Value *LHS, Value *RHS) { + if (I.getPredicate() == I.ICMP_NE) + Pred = CmpInst::getInversePredicate(Pred); + return new ICmpInst(Pred, LHS, RHS); + }; + + // Don't bother doing any work for cases which InstSimplify handles. + if (AP2.isZero()) + return nullptr; + + bool IsAShr = isa<AShrOperator>(I.getOperand(0)); + if (IsAShr) { + if (AP2.isAllOnes()) + return nullptr; + if (AP2.isNegative() != AP1.isNegative()) + return nullptr; + if (AP2.sgt(AP1)) + return nullptr; + } + + if (!AP1) + // 'A' must be large enough to shift out the highest set bit. + return getICmp(I.ICMP_UGT, A, + ConstantInt::get(A->getType(), AP2.logBase2())); + + if (AP1 == AP2) + return getICmp(I.ICMP_EQ, A, ConstantInt::getNullValue(A->getType())); + + int Shift; + if (IsAShr && AP1.isNegative()) + Shift = AP1.countLeadingOnes() - AP2.countLeadingOnes(); + else + Shift = AP1.countLeadingZeros() - AP2.countLeadingZeros(); + + if (Shift > 0) { + if (IsAShr && AP1 == AP2.ashr(Shift)) { + // There are multiple solutions if we are comparing against -1 and the LHS + // of the ashr is not a power of two. + if (AP1.isAllOnes() && !AP2.isPowerOf2()) + return getICmp(I.ICMP_UGE, A, ConstantInt::get(A->getType(), Shift)); + return getICmp(I.ICMP_EQ, A, ConstantInt::get(A->getType(), Shift)); + } else if (AP1 == AP2.lshr(Shift)) { + return getICmp(I.ICMP_EQ, A, ConstantInt::get(A->getType(), Shift)); + } + } + + // Shifting const2 will never be equal to const1. + // FIXME: This should always be handled by InstSimplify? + auto *TorF = ConstantInt::get(I.getType(), I.getPredicate() == I.ICMP_NE); + return replaceInstUsesWith(I, TorF); +} + +/// Handle "(icmp eq/ne (shl AP2, A), AP1)" -> +/// (icmp eq/ne A, TrailingZeros(AP1) - TrailingZeros(AP2)). 
Instruction *InstCombinerImpl::foldICmpShlConstConst(ICmpInst &I, Value *A,
                                                     const APInt &AP1,
                                                     const APInt &AP2) {
  // Only eq/ne compares are supported by this fold.
  assert(I.isEquality() && "Cannot fold icmp gt/lt");

  // Build the replacement compare. Call sites pass the predicate that is
  // correct for an 'eq' original; it is inverted here when I is 'ne'.
  auto getICmp = [&I](CmpInst::Predicate Pred, Value *LHS, Value *RHS) {
    if (I.getPredicate() == I.ICMP_NE)
      Pred = CmpInst::getInversePredicate(Pred);
    return new ICmpInst(Pred, LHS, RHS);
  };

  // Don't bother doing any work for cases which InstSimplify handles.
  if (AP2.isZero())
    return nullptr;

  unsigned AP2TrailingZeros = AP2.countTrailingZeros();

  // Comparing against zero: the result is zero once A is at least
  // BitWidth - TrailingZeros(AP2), i.e. once the lowest set bit of AP2 has
  // been shifted out.
  if (!AP1 && AP2TrailingZeros != 0)
    return getICmp(
        I.ICMP_UGE, A,
        ConstantInt::get(A->getType(), AP2.getBitWidth() - AP2TrailingZeros));

  // Identical constants: fold to a compare of A against zero.
  if (AP1 == AP2)
    return getICmp(I.ICMP_EQ, A, ConstantInt::getNullValue(A->getType()));

  // Get the distance between the lowest bits that are set.
  int Shift = AP1.countTrailingZeros() - AP2TrailingZeros;

  // The candidate amount is only used if it really maps AP2 onto AP1.
  if (Shift > 0 && AP2.shl(Shift) == AP1)
    return getICmp(I.ICMP_EQ, A, ConstantInt::get(A->getType(), Shift));

  // Shifting const2 will never be equal to const1.
  // FIXME: This should always be handled by InstSimplify?
  // The compare folds to a constant: true for 'ne', false for 'eq'.
  auto *TorF = ConstantInt::get(I.getType(), I.getPredicate() == I.ICMP_NE);
  return replaceInstUsesWith(I, TorF);
}

/// The caller has matched a pattern of the form:
///   I = icmp ugt (add (add A, B), CI2), CI1
/// If this is of the form:
///   sum = a + b
///   if (sum+128 >u 255)
/// Then replace it with llvm.sadd.with.overflow.i8.
///
/// \param I    The matched 'icmp ugt'.
/// \param A    First operand of the inner add.
/// \param B    Second operand of the inner add.
/// \param CI2  Constant added to the inner sum (the bias, e.g. 128).
/// \param CI1  Constant the biased sum is compared against (e.g. 255).
/// \param IC   Combiner whose builder and replace/erase helpers are used.
/// \returns an extractvalue of the intrinsic's overflow bit that replaces
/// \p I, or nullptr when the pattern is not handled.
static Instruction *processUGT_ADDCST_ADD(ICmpInst &I, Value *A, Value *B,
                                          ConstantInt *CI2, ConstantInt *CI1,
                                          InstCombinerImpl &IC) {
  // The transformation we're trying to do here is to transform this into an
  // llvm.sadd.with.overflow. To do this, we have to replace the original add
  // with a narrower add, and discard the add-with-constant that is part of the
  // range check (if we can't eliminate it, this isn't profitable).

  // In order to eliminate the add-with-constant, the compare can be its only
  // use.
  Instruction *AddWithCst = cast<Instruction>(I.getOperand(0));
  if (!AddWithCst->hasOneUse())
    return nullptr;

  // If CI2 is 2^7, 2^15, 2^31, then it might be an sadd.with.overflow.
  if (!CI2->getValue().isPowerOf2())
    return nullptr;
  unsigned NewWidth = CI2->getValue().countTrailingZeros();
  if (NewWidth != 7 && NewWidth != 15 && NewWidth != 31)
    return nullptr;

  // The width of the new add formed is 1 more than the bias.
  ++NewWidth;

  // Check to see that CI1 is an all-ones value with NewWidth bits.
  // Also bail out when the compare is already NewWidth bits wide.
  if (CI1->getBitWidth() == NewWidth ||
      CI1->getValue() != APInt::getLowBitsSet(CI1->getBitWidth(), NewWidth))
    return nullptr;

  // This is only really a signed overflow check if the inputs have been
  // sign-extended; check for that condition. For example, if CI2 is 2^31 and
  // the operands of the add are 64 bits wide, we need at least 33 sign bits.
  if (IC.ComputeMaxSignificantBits(A, 0, &I) > NewWidth ||
      IC.ComputeMaxSignificantBits(B, 0, &I) > NewWidth)
    return nullptr;

  // In order to replace the original add with a narrower
  // llvm.sadd.with.overflow, the only uses allowed are the add-with-constant
  // and truncates that discard the high bits of the add. Verify that this is
  // the case.
  Instruction *OrigAdd = cast<Instruction>(AddWithCst->getOperand(0));
  for (User *U : OrigAdd->users()) {
    if (U == AddWithCst)
      continue;

    // Only accept truncates for now. We would really like a nice recursive
    // predicate like SimplifyDemandedBits, but which goes downwards the use-def
    // chain to see which bits of a value are actually demanded. If the
    // original add had another add which was then immediately truncated, we
    // could still do the transformation.
    TruncInst *TI = dyn_cast<TruncInst>(U);
    if (!TI || TI->getType()->getPrimitiveSizeInBits() > NewWidth)
      return nullptr;
  }

  // If the pattern matches, truncate the inputs to the narrower type and
  // use the sadd_with_overflow intrinsic to efficiently compute both the
  // result and the overflow bit.
  Type *NewType = IntegerType::get(OrigAdd->getContext(), NewWidth);
  Function *F = Intrinsic::getDeclaration(
      I.getModule(), Intrinsic::sadd_with_overflow, NewType);

  InstCombiner::BuilderTy &Builder = IC.Builder;

  // Put the new code above the original add, in case there are any uses of the
  // add between the add and the compare.
  Builder.SetInsertPoint(OrigAdd);

  Value *TruncA = Builder.CreateTrunc(A, NewType, A->getName() + ".trunc");
  Value *TruncB = Builder.CreateTrunc(B, NewType, B->getName() + ".trunc");
  CallInst *Call = Builder.CreateCall(F, {TruncA, TruncB}, "sadd");
  // Element 0 of the intrinsic's result is the narrow sum.
  Value *Add = Builder.CreateExtractValue(Call, 0, "sadd.result");
  Value *ZExt = Builder.CreateZExt(Add, OrigAdd->getType());

  // The inner add was the result of the narrow add, zero extended to the
  // wider type. Replace it with the result computed by the intrinsic.
  IC.replaceInstUsesWith(*OrigAdd, ZExt);
  IC.eraseInstFromFunction(*OrigAdd);

  // The original icmp gets replaced with the overflow value.
  // Element 1 of the intrinsic's result is the overflow bit.
  return ExtractValueInst::Create(Call, 1, "sadd.overflow");
}

/// If we have:
///   icmp eq/ne (urem/srem %x, %y), 0
/// iff %y is a power-of-two, we can replace this with a bit test:
///   icmp eq/ne (and %x, (add %y, -1)), 0
///
/// \returns the new compare, or nullptr if \p I does not match. The rem must
/// have a single use and %y must be known to be a power of two (or zero).
Instruction *InstCombinerImpl::foldIRemByPowerOfTwoToBitTest(ICmpInst &I) {
  // This fold is only valid for equality predicates.
  if (!I.isEquality())
    return nullptr;
  ICmpInst::Predicate Pred;
  Value *X, *Y, *Zero;
  // 'Zero' captures the existing zero operand so it can be reused as the RHS
  // of the new compare.
  if (!match(&I, m_ICmp(Pred, m_OneUse(m_IRem(m_Value(X), m_Value(Y))),
                        m_CombineAnd(m_Zero(), m_Value(Zero)))))
    return nullptr;
  if (!isKnownToBeAPowerOfTwo(Y, /*OrZero*/ true, 0, &I))
    return nullptr;
  // This may increase instruction count, we don't enforce that Y is a constant.
  // Mask = Y - 1, expressed as Y + (-1).
  Value *Mask = Builder.CreateAdd(Y, Constant::getAllOnesValue(Y->getType()));
  Value *Masked = Builder.CreateAnd(X, Mask);
  return ICmpInst::Create(Instruction::ICmp, Pred, Masked, Zero);
}

/// Fold equality-comparison between zero and any (maybe truncated) right-shift
/// by one-less-than-bitwidth into a sign test on the original value.
///
/// 'eq 0' becomes 'X s>= 0' and 'ne 0' becomes 'X s< 0'. Returns nullptr if
/// \p I does not match.
Instruction *InstCombinerImpl::foldSignBitTest(ICmpInst &I) {
  Instruction *Val;
  ICmpInst::Predicate Pred;
  if (!I.isEquality() || !match(&I, m_ICmp(Pred, m_Instruction(Val), m_Zero())))
    return nullptr;

  Value *X;
  Type *XTy;

  Constant *C;
  if (match(Val, m_TruncOrSelf(m_Shr(m_Value(X), m_Constant(C))))) {
    // Direct form: the shift amount must equal bitwidth(X) - 1.
    XTy = X->getType();
    unsigned XBitWidth = XTy->getScalarSizeInBits();
    if (!match(C, m_SpecificInt_ICMP(ICmpInst::Predicate::ICMP_EQ,
                                     APInt(XBitWidth, XBitWidth - 1))))
      return nullptr;
  } else if (isa<BinaryOperator>(Val) &&
             (X = reassociateShiftAmtsOfTwoSameDirectionShifts(
                  cast<BinaryOperator>(Val), SQ.getWithInstruction(Val),
                  /*AnalyzeForSignBitExtraction=*/true))) {
    // Otherwise let the shift-reassociation helper find X; a null result
    // means no match.
    XTy = X->getType();
  } else
    return nullptr;

  return ICmpInst::Create(Instruction::ICmp,
                          Pred == ICmpInst::ICMP_EQ ? ICmpInst::ICMP_SGE
                                                    : ICmpInst::ICMP_SLT,
                          X, ConstantInt::getNullValue(XTy));
}

// Handle icmp pred X, 0
// Returns the replacement instruction, or nullptr if operand 1 is not zero or
// none of the folds below apply.
Instruction *InstCombinerImpl::foldICmpWithZero(ICmpInst &Cmp) {
  CmpInst::Predicate Pred = Cmp.getPredicate();
  if (!match(Cmp.getOperand(1), m_Zero()))
    return nullptr;

  // (icmp sgt smin(PosA, B) 0) -> (icmp sgt B 0)
  // Either smin operand may be the known-positive one; the other is kept.
  if (Pred == ICmpInst::ICMP_SGT) {
    Value *A, *B;
    if (match(Cmp.getOperand(0), m_SMin(m_Value(A), m_Value(B)))) {
      if (isKnownPositive(A, DL, 0, &AC, &Cmp, &DT))
        return new ICmpInst(Pred, B, Cmp.getOperand(1));
      if (isKnownPositive(B, DL, 0, &AC, &Cmp, &DT))
        return new ICmpInst(Pred, A, Cmp.getOperand(1));
    }
  }

  if (Instruction *New = foldIRemByPowerOfTwoToBitTest(Cmp))
    return New;

  // Given:
  //   icmp eq/ne (urem %x, %y), 0
  // Iff %x has 0 or 1 bits set, and %y has at least 2 bits set, omit 'urem':
  //   icmp eq/ne %x, 0
  Value *X, *Y;
  if (match(Cmp.getOperand(0), m_URem(m_Value(X), m_Value(Y))) &&
      ICmpInst::isEquality(Pred)) {
    KnownBits XKnown = computeKnownBits(X, 0, &Cmp);
    KnownBits YKnown = computeKnownBits(Y, 0, &Cmp);
    if (XKnown.countMaxPopulation() == 1 && YKnown.countMinPopulation() >= 2)
      return new ICmpInst(Pred, X, Cmp.getOperand(1));
  }

  return nullptr;
}

/// Fold icmp Pred X, C.
/// TODO: This code structure does not make sense. The saturating add fold
/// should be moved to some other helper and extended as noted below (it is also
/// possible that code has been made unnecessary - do we canonicalize IR to
/// overflow/saturating intrinsics or not?).
///
/// Two folds are tried in order: the sadd.with.overflow idiom, then pushing a
/// compare against a constant through a phi whose operands are all constants.
Instruction *InstCombinerImpl::foldICmpWithConstant(ICmpInst &Cmp) {
  // Match the following pattern, which is a common idiom when writing
  // overflow-safe integer arithmetic functions. The source performs an addition
  // in wider type and explicitly checks for overflow using comparisons against
  // INT_MIN and INT_MAX. Simplify by using the sadd_with_overflow intrinsic.
  //
  // TODO: This could probably be generalized to handle other overflow-safe
  // operations if we worked out the formulas to compute the appropriate magic
  // constants.
  //
  // sum = a + b
  // if (sum+128 >u 255)  ...  -> llvm.sadd.with.overflow.i8
  CmpInst::Predicate Pred = Cmp.getPredicate();
  Value *Op0 = Cmp.getOperand(0), *Op1 = Cmp.getOperand(1);
  Value *A, *B;
  ConstantInt *CI, *CI2; // I = icmp ugt (add (add A, B), CI2), CI
  if (Pred == ICmpInst::ICMP_UGT && match(Op1, m_ConstantInt(CI)) &&
      match(Op0, m_Add(m_Add(m_Value(A), m_Value(B)), m_ConstantInt(CI2))))
    if (Instruction *Res = processUGT_ADDCST_ADD(Cmp, A, B, CI2, CI, *this))
      return Res;

  // icmp(phi(C1, C2, ...), C) -> phi(icmp(C1, C), icmp(C2, C), ...).
  Constant *C = dyn_cast<Constant>(Op1);
  if (!C || C->canTrap())
    return nullptr;

  if (auto *Phi = dyn_cast<PHINode>(Op0))
    if (all_of(Phi->operands(), [](Value *V) { return isa<Constant>(V); })) {
      Type *Ty = Cmp.getType();
      // The replacement phi is created at the position of the original phi.
      Builder.SetInsertPoint(Phi);
      PHINode *NewPhi =
          Builder.CreatePHI(Ty, Phi->getNumOperands());
      // One constant-folded compare per predecessor of the phi's block.
      for (BasicBlock *Predecessor : predecessors(Phi->getParent())) {
        auto *Input =
            cast<Constant>(Phi->getIncomingValueForBlock(Predecessor));
        auto *BoolInput = ConstantExpr::getCompare(Pred, Input, C);
        NewPhi->addIncoming(BoolInput, Predecessor);
      }
      NewPhi->takeName(&Cmp);
      return replaceInstUsesWith(Cmp, NewPhi);
    }

  return nullptr;
}

/// Canonicalize icmp instructions based on dominating conditions.
///
/// Only a single predecessor ending in a conditional branch is considered.
/// Returns the replacement, or nullptr if nothing could be simplified.
Instruction *InstCombinerImpl::foldICmpWithDominatingICmp(ICmpInst &Cmp) {
  // This is a cheap/incomplete check for dominance - just match a single
  // predecessor with a conditional branch.
  BasicBlock *CmpBB = Cmp.getParent();
  BasicBlock *DomBB = CmpBB->getSinglePredecessor();
  if (!DomBB)
    return nullptr;

  Value *DomCond;
  BasicBlock *TrueBB, *FalseBB;
  if (!match(DomBB->getTerminator(), m_Br(m_Value(DomCond), TrueBB, FalseBB)))
    return nullptr;

  assert((TrueBB == CmpBB || FalseBB == CmpBB) &&
         "Predecessor block does not point to successor?");

  // The branch should get simplified. Don't bother simplifying this condition.
  if (TrueBB == FalseBB)
    return nullptr;

  // Try to simplify this compare to T/F based on the dominating condition.
  // The last argument says whether CmpBB is reached on the true edge.
  Optional<bool> Imp = isImpliedCondition(DomCond, &Cmp, DL, TrueBB == CmpBB);
  if (Imp)
    return replaceInstUsesWith(Cmp, ConstantInt::get(Cmp.getType(), *Imp));

  CmpInst::Predicate Pred = Cmp.getPredicate();
  Value *X = Cmp.getOperand(0), *Y = Cmp.getOperand(1);
  ICmpInst::Predicate DomPred;
  const APInt *C, *DomC;
  if (match(DomCond, m_ICmp(DomPred, m_Specific(X), m_APInt(DomC))) &&
      match(Y, m_APInt(C))) {
    // We have 2 compares of a variable with constants. Calculate the constant
    // ranges of those compares to see if we can transform the 2nd compare:
    // DomBB:
    //   DomCond = icmp DomPred X, DomC
    //   br DomCond, CmpBB, FalseBB
    // CmpBB:
    //   Cmp = icmp Pred X, C
    ConstantRange CR = ConstantRange::makeExactICmpRegion(Pred, *C);
    // On the false edge the dominating condition is known not to hold, so its
    // inverse predicate is used.
    ConstantRange DominatingCR =
        (CmpBB == TrueBB) ? ConstantRange::makeExactICmpRegion(DomPred, *DomC)
                          : ConstantRange::makeExactICmpRegion(
                                CmpInst::getInversePredicate(DomPred), *DomC);
    ConstantRange Intersection = DominatingCR.intersectWith(CR);
    ConstantRange Difference = DominatingCR.difference(CR);
    // No value allowed by the dominating condition satisfies Cmp.
    if (Intersection.isEmptySet())
      return replaceInstUsesWith(Cmp, Builder.getFalse());
    // Every value allowed by the dominating condition satisfies Cmp.
    if (Difference.isEmptySet())
      return replaceInstUsesWith(Cmp, Builder.getTrue());

    // Canonicalizing a sign bit comparison that gets used in a branch,
    // pessimizes codegen by generating branch on zero instruction instead
    // of a test and branch. So we avoid canonicalizing in such situations
    // because test and branch instruction has better branch displacement
    // than compare and branch instruction.
    bool UnusedBit;
    bool IsSignBit = isSignBitCheck(Pred, *C, UnusedBit);
    if (Cmp.isEquality() || (IsSignBit && hasBranchUse(Cmp)))
      return nullptr;

    // Avoid an infinite loop with min/max canonicalization.
    // TODO: This will be unnecessary if we canonicalize to min/max intrinsics.
    if (Cmp.hasOneUse() &&
        match(Cmp.user_back(), m_MaxOrMin(m_Value(), m_Value())))
      return nullptr;

    // A single remaining candidate value turns the compare into eq/ne.
    if (const APInt *EqC = Intersection.getSingleElement())
      return new ICmpInst(ICmpInst::ICMP_EQ, X, Builder.getInt(*EqC));
    if (const APInt *NeC = Difference.getSingleElement())
      return new ICmpInst(ICmpInst::ICMP_NE, X, Builder.getInt(*NeC));
  }

  return nullptr;
}

/// Fold icmp (trunc X), C.
///
/// \param Cmp   The compare being visited.
/// \param Trunc The truncate that is operand 0 of \p Cmp.
/// \param C     The constant the truncated value is compared against.
/// \returns the replacement compare, or nullptr if no fold applies.
Instruction *InstCombinerImpl::foldICmpTruncConstant(ICmpInst &Cmp,
                                                     TruncInst *Trunc,
                                                     const APInt &C) {
  ICmpInst::Predicate Pred = Cmp.getPredicate();
  Value *X = Trunc->getOperand(0);
  if (C.isOne() && C.getBitWidth() > 1) {
    // icmp slt trunc(signum(V)) 1 --> icmp slt V, 1
    Value *V = nullptr;
    if (Pred == ICmpInst::ICMP_SLT && match(X, m_Signum(m_Value(V))))
      return new ICmpInst(ICmpInst::ICMP_SLT, V,
                          ConstantInt::get(V->getType(), 1));
  }

  unsigned DstBits = Trunc->getType()->getScalarSizeInBits(),
           SrcBits = X->getType()->getScalarSizeInBits();
  if (Cmp.isEquality() && Trunc->hasOneUse()) {
    // Canonicalize to a mask and wider compare if the wide type is suitable:
    // (trunc X to i8) == C --> (X & 0xff) == (zext C)
    if (!X->getType()->isVectorTy() && shouldChangeType(DstBits, SrcBits)) {
      Constant *Mask = ConstantInt::get(X->getType(),
                                        APInt::getLowBitsSet(SrcBits, DstBits));
      Value *And = Builder.CreateAnd(X, Mask);
      Constant *WideC = ConstantInt::get(X->getType(), C.zext(SrcBits));
      return new ICmpInst(Pred, And, WideC);
    }

    // Simplify icmp eq (trunc x to i8), 42 -> icmp eq x, 42|highbits if all
    // of the high bits truncated out of x are known.
    KnownBits Known = computeKnownBits(X, 0, &Cmp);

    // If all the high bits are known, we can do this xform.
    if ((Known.Zero | Known.One).countLeadingOnes() >= SrcBits - DstBits) {
      // Pull in the high bits from known-ones set.
      APInt NewRHS = C.zext(SrcBits);
      NewRHS |= Known.One & APInt::getHighBitsSet(SrcBits, SrcBits - DstBits);
      return new ICmpInst(Pred, X, ConstantInt::get(X->getType(), NewRHS));
    }
  }

  // Look through truncated right-shift of the sign-bit for a sign-bit check:
  // trunc iN (ShOp >> ShAmtC) to i[N - ShAmtC] < 0  --> ShOp <  0
  // trunc iN (ShOp >> ShAmtC) to i[N - ShAmtC] > -1 --> ShOp > -1
  Value *ShOp;
  const APInt *ShAmtC;
  bool TrueIfSigned;
  if (isSignBitCheck(Pred, C, TrueIfSigned) &&
      match(X, m_Shr(m_Value(ShOp), m_APInt(ShAmtC))) &&
      DstBits == SrcBits - ShAmtC->getZExtValue()) {
    return TrueIfSigned
               ? new ICmpInst(ICmpInst::ICMP_SLT, ShOp,
                              ConstantInt::getNullValue(X->getType()))
               : new ICmpInst(ICmpInst::ICMP_SGT, ShOp,
                              ConstantInt::getAllOnesValue(X->getType()));
  }

  return nullptr;
}

/// Fold icmp (xor X, Y), C.
///
/// Nothing is folded unless Y matches a constant (XorC below).
///
/// \param Cmp The compare being visited.
/// \param Xor The xor that is operand 0 of \p Cmp.
/// \param C   The constant the xor result is compared against.
/// \returns the replacement compare, or nullptr if no fold applies.
Instruction *InstCombinerImpl::foldICmpXorConstant(ICmpInst &Cmp,
                                                   BinaryOperator *Xor,
                                                   const APInt &C) {
  Value *X = Xor->getOperand(0);
  Value *Y = Xor->getOperand(1);
  const APInt *XorC;
  if (!match(Y, m_APInt(XorC)))
    return nullptr;

  // If this is a comparison that tests the signbit (X < 0) or (x > -1),
  // fold the xor.
  ICmpInst::Predicate Pred = Cmp.getPredicate();
  bool TrueIfSigned = false;
  if (isSignBitCheck(Cmp.getPredicate(), C, TrueIfSigned)) {

    // If the sign bit of the XorCst is not set, there is no change to
    // the operation, just stop using the Xor.
    if (!XorC->isNegative())
      return replaceOperand(Cmp, 0, X);

    // Emit the opposite comparison.
    if (TrueIfSigned)
      return new ICmpInst(ICmpInst::ICMP_SGT, X,
                          ConstantInt::getAllOnesValue(X->getType()));
    else
      return new ICmpInst(ICmpInst::ICMP_SLT, X,
                          ConstantInt::getNullValue(X->getType()));
  }

  if (Xor->hasOneUse()) {
    // (icmp u/s (xor X SignMask), C) -> (icmp s/u X, (xor C SignMask))
    if (!Cmp.isEquality() && XorC->isSignMask()) {
      Pred = Cmp.getFlippedSignednessPredicate();
      return new ICmpInst(Pred, X, ConstantInt::get(X->getType(), C ^ *XorC));
    }

    // (icmp u/s (xor X ~SignMask), C) -> (icmp s/u X, (xor C ~SignMask))
    // Here the predicate is also swapped after its signedness is flipped.
    if (!Cmp.isEquality() && XorC->isMaxSignedValue()) {
      Pred = Cmp.getFlippedSignednessPredicate();
      Pred = Cmp.getSwappedPredicate(Pred);
      return new ICmpInst(Pred, X, ConstantInt::get(X->getType(), C ^ *XorC));
    }
  }

  // Mask constant magic can eliminate an 'xor' with unsigned compares.
  if (Pred == ICmpInst::ICMP_UGT) {
    // (xor X, ~C) >u C --> X <u ~C (when C+1 is a power of 2)
    if (*XorC == ~C && (C + 1).isPowerOf2())
      return new ICmpInst(ICmpInst::ICMP_ULT, X, Y);
    // (xor X, C) >u C --> X >u C (when C+1 is a power of 2)
    if (*XorC == C && (C + 1).isPowerOf2())
      return new ICmpInst(ICmpInst::ICMP_UGT, X, Y);
  }
  if (Pred == ICmpInst::ICMP_ULT) {
    // (xor X, -C) <u C --> X >u ~C (when C is a power of 2)
    if (*XorC == -C && C.isPowerOf2())
      return new ICmpInst(ICmpInst::ICMP_UGT, X,
                          ConstantInt::get(X->getType(), ~C));
    // (xor X, C) <u C --> X >u ~C (when -C is a power of 2)
    if (*XorC == C && (-C).isPowerOf2())
      return new ICmpInst(ICmpInst::ICMP_UGT, X,
                          ConstantInt::get(X->getType(), ~C));
  }
  return nullptr;
}

/// Fold icmp (and (sh X, Y), C2), C1.
+Instruction *InstCombinerImpl::foldICmpAndShift(ICmpInst &Cmp, + BinaryOperator *And, + const APInt &C1, + const APInt &C2) { + BinaryOperator *Shift = dyn_cast<BinaryOperator>(And->getOperand(0)); + if (!Shift || !Shift->isShift()) + return nullptr; + + // If this is: (X >> C3) & C2 != C1 (where any shift and any compare could + // exist), turn it into (X & (C2 << C3)) != (C1 << C3). This happens a LOT in + // code produced by the clang front-end, for bitfield access. + // This seemingly simple opportunity to fold away a shift turns out to be + // rather complicated. See PR17827 for details. + unsigned ShiftOpcode = Shift->getOpcode(); + bool IsShl = ShiftOpcode == Instruction::Shl; + const APInt *C3; + if (match(Shift->getOperand(1), m_APInt(C3))) { + APInt NewAndCst, NewCmpCst; + bool AnyCmpCstBitsShiftedOut; + if (ShiftOpcode == Instruction::Shl) { + // For a left shift, we can fold if the comparison is not signed. We can + // also fold a signed comparison if the mask value and comparison value + // are not negative. These constraints may not be obvious, but we can + // prove that they are correct using an SMT solver. + if (Cmp.isSigned() && (C2.isNegative() || C1.isNegative())) + return nullptr; + + NewCmpCst = C1.lshr(*C3); + NewAndCst = C2.lshr(*C3); + AnyCmpCstBitsShiftedOut = NewCmpCst.shl(*C3) != C1; + } else if (ShiftOpcode == Instruction::LShr) { + // For a logical right shift, we can fold if the comparison is not signed. + // We can also fold a signed comparison if the shifted mask value and the + // shifted comparison value are not negative. These constraints may not be + // obvious, but we can prove that they are correct using an SMT solver. 
+ NewCmpCst = C1.shl(*C3); + NewAndCst = C2.shl(*C3); + AnyCmpCstBitsShiftedOut = NewCmpCst.lshr(*C3) != C1; + if (Cmp.isSigned() && (NewAndCst.isNegative() || NewCmpCst.isNegative())) + return nullptr; + } else { + // For an arithmetic shift, check that both constants don't use (in a + // signed sense) the top bits being shifted out. + assert(ShiftOpcode == Instruction::AShr && "Unknown shift opcode"); + NewCmpCst = C1.shl(*C3); + NewAndCst = C2.shl(*C3); + AnyCmpCstBitsShiftedOut = NewCmpCst.ashr(*C3) != C1; + if (NewAndCst.ashr(*C3) != C2) + return nullptr; + } + + if (AnyCmpCstBitsShiftedOut) { + // If we shifted bits out, the fold is not going to work out. As a + // special case, check to see if this means that the result is always + // true or false now. + if (Cmp.getPredicate() == ICmpInst::ICMP_EQ) + return replaceInstUsesWith(Cmp, ConstantInt::getFalse(Cmp.getType())); + if (Cmp.getPredicate() == ICmpInst::ICMP_NE) + return replaceInstUsesWith(Cmp, ConstantInt::getTrue(Cmp.getType())); + } else { + Value *NewAnd = Builder.CreateAnd( + Shift->getOperand(0), ConstantInt::get(And->getType(), NewAndCst)); + return new ICmpInst(Cmp.getPredicate(), + NewAnd, ConstantInt::get(And->getType(), NewCmpCst)); + } + } + + // Turn ((X >> Y) & C2) == 0 into (X & (C2 << Y)) == 0. The latter is + // preferable because it allows the C2 << Y expression to be hoisted out of a + // loop if Y is invariant and X is not. + if (Shift->hasOneUse() && C1.isZero() && Cmp.isEquality() && + !Shift->isArithmeticShift() && !isa<Constant>(Shift->getOperand(0))) { + // Compute C2 << Y. + Value *NewShift = + IsShl ? Builder.CreateLShr(And->getOperand(1), Shift->getOperand(1)) + : Builder.CreateShl(And->getOperand(1), Shift->getOperand(1)); + + // Compute X & (C2 << Y). + Value *NewAnd = Builder.CreateAnd(Shift->getOperand(0), NewShift); + return replaceOperand(Cmp, 0, NewAnd); + } + + return nullptr; +} + +/// Fold icmp (and X, C2), C1. 
+Instruction *InstCombinerImpl::foldICmpAndConstConst(ICmpInst &Cmp, + BinaryOperator *And, + const APInt &C1) { + bool isICMP_NE = Cmp.getPredicate() == ICmpInst::ICMP_NE; + + // For vectors: icmp ne (and X, 1), 0 --> trunc X to N x i1 + // TODO: We canonicalize to the longer form for scalars because we have + // better analysis/folds for icmp, and codegen may be better with icmp. + if (isICMP_NE && Cmp.getType()->isVectorTy() && C1.isZero() && + match(And->getOperand(1), m_One())) + return new TruncInst(And->getOperand(0), Cmp.getType()); + + const APInt *C2; + Value *X; + if (!match(And, m_And(m_Value(X), m_APInt(C2)))) + return nullptr; + + // Don't perform the following transforms if the AND has multiple uses + if (!And->hasOneUse()) + return nullptr; + + if (Cmp.isEquality() && C1.isZero()) { + // Restrict this fold to single-use 'and' (PR10267). + // Replace (and X, (1 << size(X)-1) != 0) with X s< 0 + if (C2->isSignMask()) { + Constant *Zero = Constant::getNullValue(X->getType()); + auto NewPred = isICMP_NE ? ICmpInst::ICMP_SLT : ICmpInst::ICMP_SGE; + return new ICmpInst(NewPred, X, Zero); + } + + // Restrict this fold only for single-use 'and' (PR10267). + // ((%x & C) == 0) --> %x u< (-C) iff (-C) is power of two. + if ((~(*C2) + 1).isPowerOf2()) { + Constant *NegBOC = + ConstantExpr::getNeg(cast<Constant>(And->getOperand(1))); + auto NewPred = isICMP_NE ? ICmpInst::ICMP_UGE : ICmpInst::ICMP_ULT; + return new ICmpInst(NewPred, X, NegBOC); + } + } + + // If the LHS is an 'and' of a truncate and we can widen the and/compare to + // the input width without changing the value produced, eliminate the cast: + // + // icmp (and (trunc W), C2), C1 -> icmp (and W, C2'), C1' + // + // We can do this transformation if the constants do not have their sign bits + // set or if it is an equality comparison. Extending a relational comparison + // when we're checking the sign bit would not work. 
+ Value *W; + if (match(And->getOperand(0), m_OneUse(m_Trunc(m_Value(W)))) && + (Cmp.isEquality() || (!C1.isNegative() && !C2->isNegative()))) { + // TODO: Is this a good transform for vectors? Wider types may reduce + // throughput. Should this transform be limited (even for scalars) by using + // shouldChangeType()? + if (!Cmp.getType()->isVectorTy()) { + Type *WideType = W->getType(); + unsigned WideScalarBits = WideType->getScalarSizeInBits(); + Constant *ZextC1 = ConstantInt::get(WideType, C1.zext(WideScalarBits)); + Constant *ZextC2 = ConstantInt::get(WideType, C2->zext(WideScalarBits)); + Value *NewAnd = Builder.CreateAnd(W, ZextC2, And->getName()); + return new ICmpInst(Cmp.getPredicate(), NewAnd, ZextC1); + } + } + + if (Instruction *I = foldICmpAndShift(Cmp, And, C1, *C2)) + return I; + + // (icmp pred (and (or (lshr A, B), A), 1), 0) --> + // (icmp pred (and A, (or (shl 1, B), 1), 0)) + // + // iff pred isn't signed + if (!Cmp.isSigned() && C1.isZero() && And->getOperand(0)->hasOneUse() && + match(And->getOperand(1), m_One())) { + Constant *One = cast<Constant>(And->getOperand(1)); + Value *Or = And->getOperand(0); + Value *A, *B, *LShr; + if (match(Or, m_Or(m_Value(LShr), m_Value(A))) && + match(LShr, m_LShr(m_Specific(A), m_Value(B)))) { + unsigned UsesRemoved = 0; + if (And->hasOneUse()) + ++UsesRemoved; + if (Or->hasOneUse()) + ++UsesRemoved; + if (LShr->hasOneUse()) + ++UsesRemoved; + + // Compute A & ((1 << B) | 1) + Value *NewOr = nullptr; + if (auto *C = dyn_cast<Constant>(B)) { + if (UsesRemoved >= 1) + NewOr = ConstantExpr::getOr(ConstantExpr::getNUWShl(One, C), One); + } else { + if (UsesRemoved >= 3) + NewOr = Builder.CreateOr(Builder.CreateShl(One, B, LShr->getName(), + /*HasNUW=*/true), + One, Or->getName()); + } + if (NewOr) { + Value *NewAnd = Builder.CreateAnd(A, NewOr, And->getName()); + return replaceOperand(Cmp, 0, NewAnd); + } + } + } + + return nullptr; +} + +/// Fold icmp (and X, Y), C. 
+Instruction *InstCombinerImpl::foldICmpAndConstant(ICmpInst &Cmp, + BinaryOperator *And, + const APInt &C) { + if (Instruction *I = foldICmpAndConstConst(Cmp, And, C)) + return I; + + const ICmpInst::Predicate Pred = Cmp.getPredicate(); + bool TrueIfNeg; + if (isSignBitCheck(Pred, C, TrueIfNeg)) { + // ((X - 1) & ~X) < 0 --> X == 0 + // ((X - 1) & ~X) >= 0 --> X != 0 + Value *X; + if (match(And->getOperand(0), m_Add(m_Value(X), m_AllOnes())) && + match(And->getOperand(1), m_Not(m_Specific(X)))) { + auto NewPred = TrueIfNeg ? CmpInst::ICMP_EQ : CmpInst::ICMP_NE; + return new ICmpInst(NewPred, X, ConstantInt::getNullValue(X->getType())); + } + } + + // TODO: These all require that Y is constant too, so refactor with the above. + + // Try to optimize things like "A[i] & 42 == 0" to index computations. + Value *X = And->getOperand(0); + Value *Y = And->getOperand(1); + if (auto *C2 = dyn_cast<ConstantInt>(Y)) + if (auto *LI = dyn_cast<LoadInst>(X)) + if (auto *GEP = dyn_cast<GetElementPtrInst>(LI->getOperand(0))) + if (auto *GV = dyn_cast<GlobalVariable>(GEP->getOperand(0))) + if (Instruction *Res = + foldCmpLoadFromIndexedGlobal(LI, GEP, GV, Cmp, C2)) + return Res; + + if (!Cmp.isEquality()) + return nullptr; + + // X & -C == -C -> X > u ~C + // X & -C != -C -> X <= u ~C + // iff C is a power of 2 + if (Cmp.getOperand(1) == Y && C.isNegatedPowerOf2()) { + auto NewPred = + Pred == CmpInst::ICMP_EQ ? CmpInst::ICMP_UGT : CmpInst::ICMP_ULE; + return new ICmpInst(NewPred, X, SubOne(cast<Constant>(Cmp.getOperand(1)))); + } + + return nullptr; +} + +/// Fold icmp (or X, Y), C. 
// Parameters: Cmp is the compare being visited, Or is the 'or' feeding it and
// C is the constant compared against. Returns the replacement instruction,
// or nullptr if no fold applies.
Instruction *InstCombinerImpl::foldICmpOrConstant(ICmpInst &Cmp,
                                                  BinaryOperator *Or,
                                                  const APInt &C) {
  ICmpInst::Predicate Pred = Cmp.getPredicate();
  if (C.isOne()) {
    // icmp slt signum(V) 1 --> icmp slt V, 1
    Value *V = nullptr;
    if (Pred == ICmpInst::ICMP_SLT && match(Or, m_Signum(m_Value(V))))
      return new ICmpInst(ICmpInst::ICMP_SLT, V,
                          ConstantInt::get(V->getType(), 1));
  }

  Value *OrOp0 = Or->getOperand(0), *OrOp1 = Or->getOperand(1);
  const APInt *MaskC;
  if (match(OrOp1, m_APInt(MaskC)) && Cmp.isEquality()) {
    if (*MaskC == C && (C + 1).isPowerOf2()) {
      // X | C == C --> X <=u C
      // X | C != C --> X >u C
      //   iff C+1 is a power of 2 (C is a bitmask of the low bits)
      Pred = (Pred == CmpInst::ICMP_EQ) ? CmpInst::ICMP_ULE : CmpInst::ICMP_UGT;
      return new ICmpInst(Pred, OrOp0, OrOp1);
    }

    // More general: canonicalize 'equality with set bits mask' to
    // 'equality with clear bits mask'.
    // (X | MaskC) == C --> (X & ~MaskC) == C ^ MaskC
    // (X | MaskC) != C --> (X & ~MaskC) != C ^ MaskC
    if (Or->hasOneUse()) {
      Value *And = Builder.CreateAnd(OrOp0, ~(*MaskC));
      Constant *NewC = ConstantInt::get(Or->getType(), C ^ (*MaskC));
      return new ICmpInst(Pred, And, NewC);
    }
  }

  // (X | (X-1)) s<  0 --> X s< 1
  // (X | (X-1)) s> -1 --> X s> 0
  Value *X;
  bool TrueIfSigned;
  if (isSignBitCheck(Pred, C, TrueIfSigned) &&
      match(Or, m_c_Or(m_Add(m_Value(X), m_AllOnes()), m_Deferred(X)))) {
    auto NewPred = TrueIfSigned ? ICmpInst::ICMP_SLT : ICmpInst::ICMP_SGT;
    Constant *NewC = ConstantInt::get(X->getType(), TrueIfSigned ? 1 : 0);
    return new ICmpInst(NewPred, X, NewC);
  }

  // The remaining folds only handle a single-use 'or' compared eq/ne to zero.
  if (!Cmp.isEquality() || !C.isZero() || !Or->hasOneUse())
    return nullptr;

  Value *P, *Q;
  if (match(Or, m_Or(m_PtrToInt(m_Value(P)), m_PtrToInt(m_Value(Q))))) {
    // Simplify icmp eq (or (ptrtoint P), (ptrtoint Q)), 0
    // -> and (icmp eq P, null), (icmp eq Q, null).
    // For 'ne' the two compares are combined with 'or' instead.
    Value *CmpP =
        Builder.CreateICmp(Pred, P, ConstantInt::getNullValue(P->getType()));
    Value *CmpQ =
        Builder.CreateICmp(Pred, Q, ConstantInt::getNullValue(Q->getType()));
    auto BOpc = Pred == CmpInst::ICMP_EQ ? Instruction::And : Instruction::Or;
    return BinaryOperator::Create(BOpc, CmpP, CmpQ);
  }

  // Are we using xors to bitwise check for a pair of (in)equalities? Convert to
  // a shorter form that has more potential to be folded even further.
  Value *X1, *X2, *X3, *X4;
  if (match(OrOp0, m_OneUse(m_Xor(m_Value(X1), m_Value(X2)))) &&
      match(OrOp1, m_OneUse(m_Xor(m_Value(X3), m_Value(X4))))) {
    // ((X1 ^ X2) || (X3 ^ X4)) == 0 --> (X1 == X2) && (X3 == X4)
    // ((X1 ^ X2) || (X3 ^ X4)) != 0 --> (X1 != X2) || (X3 != X4)
    Value *Cmp12 = Builder.CreateICmp(Pred, X1, X2);
    Value *Cmp34 = Builder.CreateICmp(Pred, X3, X4);
    auto BOpc = Pred == CmpInst::ICMP_EQ ? Instruction::And : Instruction::Or;
    return BinaryOperator::Create(BOpc, Cmp12, Cmp34);
  }

  return nullptr;
}

/// Fold icmp (mul X, Y), C.
///
/// Nothing is folded unless Y matches a constant (MulC below).
///
/// \param Cmp The compare being visited.
/// \param Mul The multiply feeding \p Cmp.
/// \param C   The constant compared against.
/// \returns the replacement compare, or nullptr if no fold applies.
Instruction *InstCombinerImpl::foldICmpMulConstant(ICmpInst &Cmp,
                                                   BinaryOperator *Mul,
                                                   const APInt &C) {
  const APInt *MulC;
  if (!match(Mul->getOperand(1), m_APInt(MulC)))
    return nullptr;

  // If this is a test of the sign bit and the multiply is sign-preserving with
  // a constant operand, use the multiply LHS operand instead.
  ICmpInst::Predicate Pred = Cmp.getPredicate();
  if (isSignTest(Pred, C) && Mul->hasNoSignedWrap()) {
    // The predicate is swapped for a negative multiplier.
    if (MulC->isNegative())
      Pred = ICmpInst::getSwappedPredicate(Pred);
    return new ICmpInst(Pred, Mul->getOperand(0),
                        Constant::getNullValue(Mul->getType()));
  }

  // If the multiply does not wrap, try to divide the compare constant by the
  // multiplication factor.
  // Each case requires C to be an exact multiple of MulC.
  if (Cmp.isEquality() && !MulC->isZero()) {
    // (mul nsw X, MulC) == C --> X == C /s MulC
    if (Mul->hasNoSignedWrap() && C.srem(*MulC).isZero()) {
      Constant *NewC = ConstantInt::get(Mul->getType(), C.sdiv(*MulC));
      return new ICmpInst(Pred, Mul->getOperand(0), NewC);
    }
    // (mul nuw X, MulC) == C --> X == C /u MulC
    if (Mul->hasNoUnsignedWrap() && C.urem(*MulC).isZero()) {
      Constant *NewC = ConstantInt::get(Mul->getType(), C.udiv(*MulC));
      return new ICmpInst(Pred, Mul->getOperand(0), NewC);
    }
  }

  return nullptr;
}

/// Fold icmp (shl 1, Y), C.
///
/// \param Cmp The compare being visited.
/// \param Shl The shift feeding \p Cmp; must match (shl 1, Y).
/// \param C   The constant compared against.
/// \returns a compare on the shift amount Y, or nullptr if no fold applies.
static Instruction *foldICmpShlOne(ICmpInst &Cmp, Instruction *Shl,
                                   const APInt &C) {
  Value *Y;
  if (!match(Shl, m_Shl(m_One(), m_Value(Y))))
    return nullptr;

  Type *ShiftType = Shl->getType();
  unsigned TypeBits = C.getBitWidth();
  bool CIsPowerOf2 = C.isPowerOf2();
  ICmpInst::Predicate Pred = Cmp.getPredicate();
  if (Cmp.isUnsigned()) {
    // (1 << Y) pred C -> Y pred Log2(C)
    if (!CIsPowerOf2) {
      // (1 << Y) <  30 -> Y <= 4
      // (1 << Y) <= 30 -> Y <= 4
      // (1 << Y) >= 30 -> Y >  4
      // (1 << Y) >  30 -> Y >  4
      if (Pred == ICmpInst::ICMP_ULT)
        Pred = ICmpInst::ICMP_ULE;
      else if (Pred == ICmpInst::ICMP_UGE)
        Pred = ICmpInst::ICMP_UGT;
    }

    // (1 << Y) >= 2147483648 -> Y >= 31 -> Y == 31
    // (1 << Y) <  2147483648 -> Y <  31 -> Y != 31
    unsigned CLog2 = C.logBase2();
    if (CLog2 == TypeBits - 1) {
      if (Pred == ICmpInst::ICMP_UGE)
        Pred = ICmpInst::ICMP_EQ;
      else if (Pred == ICmpInst::ICMP_ULT)
        Pred = ICmpInst::ICMP_NE;
    }
    return new ICmpInst(Pred, Y, ConstantInt::get(ShiftType, CLog2));
  } else if (Cmp.isSigned()) {
    // Signed compares are only folded against -1 and 0; the result is a test
    // of whether Y equals bitwidth - 1.
    Constant *BitWidthMinusOne = ConstantInt::get(ShiftType, TypeBits - 1);
    if (C.isAllOnes()) {
      // (1 << Y) <= -1 -> Y == 31
      if (Pred == ICmpInst::ICMP_SLE)
        return new ICmpInst(ICmpInst::ICMP_EQ, Y, BitWidthMinusOne);

      // (1 << Y) >  -1 -> Y != 31
      if (Pred == ICmpInst::ICMP_SGT)
        return new ICmpInst(ICmpInst::ICMP_NE, Y, BitWidthMinusOne);
    } else if (!C) {
      // (1 << Y) <  0 -> Y == 31
      // (1 << Y) <= 0 -> Y == 31
      if (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SLE)
        return new ICmpInst(ICmpInst::ICMP_EQ, Y, BitWidthMinusOne);

      // (1 << Y) >= 0 -> Y != 31
      // (1 << Y) >  0 -> Y != 31
      if (Pred == ICmpInst::ICMP_SGT || Pred == ICmpInst::ICMP_SGE)
        return new ICmpInst(ICmpInst::ICMP_NE, Y, BitWidthMinusOne);
    }
  } else if (Cmp.isEquality() && CIsPowerOf2) {
    // (1 << Y) == C -> Y == Log2(C) when C is a power of two (same for ne).
    return new ICmpInst(Pred, Y, ConstantInt::get(ShiftType, C.logBase2()));
  }

  return nullptr;
}

/// Fold icmp (shl X, Y), C.
Instruction *InstCombinerImpl::foldICmpShlConstant(ICmpInst &Cmp,
                                                   BinaryOperator *Shl,
                                                   const APInt &C) {
  const APInt *ShiftVal;
  if (Cmp.isEquality() && match(Shl->getOperand(0), m_APInt(ShiftVal)))
    return foldICmpShlConstConst(Cmp, Shl->getOperand(1), C, *ShiftVal);

  const APInt *ShiftAmt;
  if (!match(Shl->getOperand(1), m_APInt(ShiftAmt)))
    return foldICmpShlOne(Cmp, Shl, C);

  // Check that the shift amount is in range. If not, don't perform undefined
  // shifts. When the shift is visited, it will be simplified.
  unsigned TypeBits = C.getBitWidth();
  if (ShiftAmt->uge(TypeBits))
    return nullptr;

  ICmpInst::Predicate Pred = Cmp.getPredicate();
  Value *X = Shl->getOperand(0);
  Type *ShType = Shl->getType();

  // NSW guarantees that we are only shifting out sign bits from the high bits,
  // so we can ASHR the compare constant without needing a mask and eliminate
  // the shift.
+ if (Shl->hasNoSignedWrap()) { + if (Pred == ICmpInst::ICMP_SGT) { + // icmp Pred (shl nsw X, ShiftAmt), C --> icmp Pred X, (C >>s ShiftAmt) + APInt ShiftedC = C.ashr(*ShiftAmt); + return new ICmpInst(Pred, X, ConstantInt::get(ShType, ShiftedC)); + } + if ((Pred == ICmpInst::ICMP_EQ || Pred == ICmpInst::ICMP_NE) && + C.ashr(*ShiftAmt).shl(*ShiftAmt) == C) { + APInt ShiftedC = C.ashr(*ShiftAmt); + return new ICmpInst(Pred, X, ConstantInt::get(ShType, ShiftedC)); + } + if (Pred == ICmpInst::ICMP_SLT) { + // SLE is the same as above, but SLE is canonicalized to SLT, so convert: + // (X << S) <=s C is equiv to X <=s (C >> S) for all C + // (X << S) <s (C + 1) is equiv to X <s (C >> S) + 1 if C <s SMAX + // (X << S) <s C is equiv to X <s ((C - 1) >> S) + 1 if C >s SMIN + assert(!C.isMinSignedValue() && "Unexpected icmp slt"); + APInt ShiftedC = (C - 1).ashr(*ShiftAmt) + 1; + return new ICmpInst(Pred, X, ConstantInt::get(ShType, ShiftedC)); + } + // If this is a signed comparison to 0 and the shift is sign preserving, + // use the shift LHS operand instead; isSignTest may change 'Pred', so only + // do that if we're sure to not continue on in this function. + if (isSignTest(Pred, C)) + return new ICmpInst(Pred, X, Constant::getNullValue(ShType)); + } + + // NUW guarantees that we are only shifting out zero bits from the high bits, + // so we can LSHR the compare constant without needing a mask and eliminate + // the shift. 
+ if (Shl->hasNoUnsignedWrap()) { + if (Pred == ICmpInst::ICMP_UGT) { + // icmp Pred (shl nuw X, ShiftAmt), C --> icmp Pred X, (C >>u ShiftAmt) + APInt ShiftedC = C.lshr(*ShiftAmt); + return new ICmpInst(Pred, X, ConstantInt::get(ShType, ShiftedC)); + } + if ((Pred == ICmpInst::ICMP_EQ || Pred == ICmpInst::ICMP_NE) && + C.lshr(*ShiftAmt).shl(*ShiftAmt) == C) { + APInt ShiftedC = C.lshr(*ShiftAmt); + return new ICmpInst(Pred, X, ConstantInt::get(ShType, ShiftedC)); + } + if (Pred == ICmpInst::ICMP_ULT) { + // ULE is the same as above, but ULE is canonicalized to ULT, so convert: + // (X << S) <=u C is equiv to X <=u (C >> S) for all C + // (X << S) <u (C + 1) is equiv to X <u (C >> S) + 1 if C <u ~0u + // (X << S) <u C is equiv to X <u ((C - 1) >> S) + 1 if C >u 0 + assert(C.ugt(0) && "ult 0 should have been eliminated"); + APInt ShiftedC = (C - 1).lshr(*ShiftAmt) + 1; + return new ICmpInst(Pred, X, ConstantInt::get(ShType, ShiftedC)); + } + } + + if (Cmp.isEquality() && Shl->hasOneUse()) { + // Strength-reduce the shift into an 'and'. + Constant *Mask = ConstantInt::get( + ShType, + APInt::getLowBitsSet(TypeBits, TypeBits - ShiftAmt->getZExtValue())); + Value *And = Builder.CreateAnd(X, Mask, Shl->getName() + ".mask"); + Constant *LShrC = ConstantInt::get(ShType, C.lshr(*ShiftAmt)); + return new ICmpInst(Pred, And, LShrC); + } + + // Otherwise, if this is a comparison of the sign bit, simplify to and/test. + bool TrueIfSigned = false; + if (Shl->hasOneUse() && isSignBitCheck(Pred, C, TrueIfSigned)) { + // (X << 31) <s 0 --> (X & 1) != 0 + Constant *Mask = ConstantInt::get( + ShType, + APInt::getOneBitSet(TypeBits, TypeBits - ShiftAmt->getZExtValue() - 1)); + Value *And = Builder.CreateAnd(X, Mask, Shl->getName() + ".mask"); + return new ICmpInst(TrueIfSigned ? ICmpInst::ICMP_NE : ICmpInst::ICMP_EQ, + And, Constant::getNullValue(ShType)); + } + + // Simplify 'shl' inequality test into 'and' equality test. 
+ if (Cmp.isUnsigned() && Shl->hasOneUse()) { + // (X l<< C2) u<=/u> C1 iff C1+1 is power of two -> X & (~C1 l>> C2) ==/!= 0 + if ((C + 1).isPowerOf2() && + (Pred == ICmpInst::ICMP_ULE || Pred == ICmpInst::ICMP_UGT)) { + Value *And = Builder.CreateAnd(X, (~C).lshr(ShiftAmt->getZExtValue())); + return new ICmpInst(Pred == ICmpInst::ICMP_ULE ? ICmpInst::ICMP_EQ + : ICmpInst::ICMP_NE, + And, Constant::getNullValue(ShType)); + } + // (X l<< C2) u</u>= C1 iff C1 is power of two -> X & (-C1 l>> C2) ==/!= 0 + if (C.isPowerOf2() && + (Pred == ICmpInst::ICMP_ULT || Pred == ICmpInst::ICMP_UGE)) { + Value *And = + Builder.CreateAnd(X, (~(C - 1)).lshr(ShiftAmt->getZExtValue())); + return new ICmpInst(Pred == ICmpInst::ICMP_ULT ? ICmpInst::ICMP_EQ + : ICmpInst::ICMP_NE, + And, Constant::getNullValue(ShType)); + } + } + + // Transform (icmp pred iM (shl iM %v, N), C) + // -> (icmp pred i(M-N) (trunc %v iM to i(M-N)), (trunc (C>>N)) + // Transform the shl to a trunc if (trunc (C>>N)) has no loss and M-N. + // This enables us to get rid of the shift in favor of a trunc that may be + // free on the target. It has the additional benefit of comparing to a + // smaller constant that may be more target-friendly. + unsigned Amt = ShiftAmt->getLimitedValue(TypeBits - 1); + if (Shl->hasOneUse() && Amt != 0 && C.countTrailingZeros() >= Amt && + DL.isLegalInteger(TypeBits - Amt)) { + Type *TruncTy = IntegerType::get(Cmp.getContext(), TypeBits - Amt); + if (auto *ShVTy = dyn_cast<VectorType>(ShType)) + TruncTy = VectorType::get(TruncTy, ShVTy->getElementCount()); + Constant *NewC = + ConstantInt::get(TruncTy, C.ashr(*ShiftAmt).trunc(TypeBits - Amt)); + return new ICmpInst(Pred, Builder.CreateTrunc(X, TruncTy), NewC); + } + + return nullptr; +} + +/// Fold icmp ({al}shr X, Y), C. 
+Instruction *InstCombinerImpl::foldICmpShrConstant(ICmpInst &Cmp, + BinaryOperator *Shr, + const APInt &C) { + // An exact shr only shifts out zero bits, so: + // icmp eq/ne (shr X, Y), 0 --> icmp eq/ne X, 0 + Value *X = Shr->getOperand(0); + CmpInst::Predicate Pred = Cmp.getPredicate(); + if (Cmp.isEquality() && Shr->isExact() && C.isZero()) + return new ICmpInst(Pred, X, Cmp.getOperand(1)); + + bool IsAShr = Shr->getOpcode() == Instruction::AShr; + const APInt *ShiftValC; + if (match(Shr->getOperand(0), m_APInt(ShiftValC))) { + if (Cmp.isEquality()) + return foldICmpShrConstConst(Cmp, Shr->getOperand(1), C, *ShiftValC); + + // If the shifted constant is a power-of-2, test the shift amount directly: + // (ShiftValC >> X) >u C --> X <u (LZ(C) - LZ(ShiftValC)) + // (ShiftValC >> X) <u C --> X >=u (LZ(C-1) - LZ(ShiftValC)) + if (!IsAShr && ShiftValC->isPowerOf2() && + (Pred == CmpInst::ICMP_UGT || Pred == CmpInst::ICMP_ULT)) { + bool IsUGT = Pred == CmpInst::ICMP_UGT; + assert(ShiftValC->uge(C) && "Expected simplify of compare"); + assert((IsUGT || !C.isZero()) && "Expected X u< 0 to simplify"); + + unsigned CmpLZ = + IsUGT ? C.countLeadingZeros() : (C - 1).countLeadingZeros(); + unsigned ShiftLZ = ShiftValC->countLeadingZeros(); + Constant *NewC = ConstantInt::get(Shr->getType(), CmpLZ - ShiftLZ); + auto NewPred = IsUGT ? CmpInst::ICMP_ULT : CmpInst::ICMP_UGE; + return new ICmpInst(NewPred, Shr->getOperand(1), NewC); + } + } + + const APInt *ShiftAmtC; + if (!match(Shr->getOperand(1), m_APInt(ShiftAmtC))) + return nullptr; + + // Check that the shift amount is in range. If not, don't perform undefined + // shifts. When the shift is visited it will be simplified. 
+ unsigned TypeBits = C.getBitWidth(); + unsigned ShAmtVal = ShiftAmtC->getLimitedValue(TypeBits); + if (ShAmtVal >= TypeBits || ShAmtVal == 0) + return nullptr; + + bool IsExact = Shr->isExact(); + Type *ShrTy = Shr->getType(); + // TODO: If we could guarantee that InstSimplify would handle all of the + // constant-value-based preconditions in the folds below, then we could assert + // those conditions rather than checking them. This is difficult because of + // undef/poison (PR34838). + if (IsAShr) { + if (IsExact || Pred == CmpInst::ICMP_SLT || Pred == CmpInst::ICMP_ULT) { + // When ShAmtC can be shifted losslessly: + // icmp PRED (ashr exact X, ShAmtC), C --> icmp PRED X, (C << ShAmtC) + // icmp slt/ult (ashr X, ShAmtC), C --> icmp slt/ult X, (C << ShAmtC) + APInt ShiftedC = C.shl(ShAmtVal); + if (ShiftedC.ashr(ShAmtVal) == C) + return new ICmpInst(Pred, X, ConstantInt::get(ShrTy, ShiftedC)); + } + if (Pred == CmpInst::ICMP_SGT) { + // icmp sgt (ashr X, ShAmtC), C --> icmp sgt X, ((C + 1) << ShAmtC) - 1 + APInt ShiftedC = (C + 1).shl(ShAmtVal) - 1; + if (!C.isMaxSignedValue() && !(C + 1).shl(ShAmtVal).isMinSignedValue() && + (ShiftedC + 1).ashr(ShAmtVal) == (C + 1)) + return new ICmpInst(Pred, X, ConstantInt::get(ShrTy, ShiftedC)); + } + if (Pred == CmpInst::ICMP_UGT) { + // icmp ugt (ashr X, ShAmtC), C --> icmp ugt X, ((C + 1) << ShAmtC) - 1 + // 'C + 1 << ShAmtC' can overflow as a signed number, so the 2nd + // clause accounts for that pattern. 
+ APInt ShiftedC = (C + 1).shl(ShAmtVal) - 1; + if ((ShiftedC + 1).ashr(ShAmtVal) == (C + 1) || + (C + 1).shl(ShAmtVal).isMinSignedValue()) + return new ICmpInst(Pred, X, ConstantInt::get(ShrTy, ShiftedC)); + } + + // If the compare constant has significant bits above the lowest sign-bit, + // then convert an unsigned cmp to a test of the sign-bit: + // (ashr X, ShiftC) u> C --> X s< 0 + // (ashr X, ShiftC) u< C --> X s> -1 + if (C.getBitWidth() > 2 && C.getNumSignBits() <= ShAmtVal) { + if (Pred == CmpInst::ICMP_UGT) { + return new ICmpInst(CmpInst::ICMP_SLT, X, + ConstantInt::getNullValue(ShrTy)); + } + if (Pred == CmpInst::ICMP_ULT) { + return new ICmpInst(CmpInst::ICMP_SGT, X, + ConstantInt::getAllOnesValue(ShrTy)); + } + } + } else { + if (Pred == CmpInst::ICMP_ULT || (Pred == CmpInst::ICMP_UGT && IsExact)) { + // icmp ult (lshr X, ShAmtC), C --> icmp ult X, (C << ShAmtC) + // icmp ugt (lshr exact X, ShAmtC), C --> icmp ugt X, (C << ShAmtC) + APInt ShiftedC = C.shl(ShAmtVal); + if (ShiftedC.lshr(ShAmtVal) == C) + return new ICmpInst(Pred, X, ConstantInt::get(ShrTy, ShiftedC)); + } + if (Pred == CmpInst::ICMP_UGT) { + // icmp ugt (lshr X, ShAmtC), C --> icmp ugt X, ((C + 1) << ShAmtC) - 1 + APInt ShiftedC = (C + 1).shl(ShAmtVal) - 1; + if ((ShiftedC + 1).lshr(ShAmtVal) == (C + 1)) + return new ICmpInst(Pred, X, ConstantInt::get(ShrTy, ShiftedC)); + } + } + + if (!Cmp.isEquality()) + return nullptr; + + // Handle equality comparisons of shift-by-constant. + + // If the comparison constant changes with the shift, the comparison cannot + // succeed (bits of the comparison constant cannot match the shifted value). + // This should be known by InstSimplify and already be folded to true/false. 
+ assert(((IsAShr && C.shl(ShAmtVal).ashr(ShAmtVal) == C) || + (!IsAShr && C.shl(ShAmtVal).lshr(ShAmtVal) == C)) && + "Expected icmp+shr simplify did not occur."); + + // If the bits shifted out are known zero, compare the unshifted value: + // (X & 4) >> 1 == 2 --> (X & 4) == 4. + if (Shr->isExact()) + return new ICmpInst(Pred, X, ConstantInt::get(ShrTy, C << ShAmtVal)); + + if (C.isZero()) { + // == 0 is u< 1. + if (Pred == CmpInst::ICMP_EQ) + return new ICmpInst(CmpInst::ICMP_ULT, X, + ConstantInt::get(ShrTy, (C + 1).shl(ShAmtVal))); + else + return new ICmpInst(CmpInst::ICMP_UGT, X, + ConstantInt::get(ShrTy, (C + 1).shl(ShAmtVal) - 1)); + } + + if (Shr->hasOneUse()) { + // Canonicalize the shift into an 'and': + // icmp eq/ne (shr X, ShAmt), C --> icmp eq/ne (and X, HiMask), (C << ShAmt) + APInt Val(APInt::getHighBitsSet(TypeBits, TypeBits - ShAmtVal)); + Constant *Mask = ConstantInt::get(ShrTy, Val); + Value *And = Builder.CreateAnd(X, Mask, Shr->getName() + ".mask"); + return new ICmpInst(Pred, And, ConstantInt::get(ShrTy, C << ShAmtVal)); + } + + return nullptr; +} + +Instruction *InstCombinerImpl::foldICmpSRemConstant(ICmpInst &Cmp, + BinaryOperator *SRem, + const APInt &C) { + // Match an 'is positive' or 'is negative' comparison of remainder by a + // constant power-of-2 value: + // (X % pow2C) sgt/slt 0 + const ICmpInst::Predicate Pred = Cmp.getPredicate(); + if (Pred != ICmpInst::ICMP_SGT && Pred != ICmpInst::ICMP_SLT && + Pred != ICmpInst::ICMP_EQ && Pred != ICmpInst::ICMP_NE) + return nullptr; + + // TODO: The one-use check is standard because we do not typically want to + // create longer instruction sequences, but this might be a special-case + // because srem is not good for analysis or codegen. + if (!SRem->hasOneUse()) + return nullptr; + + const APInt *DivisorC; + if (!match(SRem->getOperand(1), m_Power2(DivisorC))) + return nullptr; + + // For cmp_sgt/cmp_slt only zero valued C is handled. 
+ // For cmp_eq/cmp_ne only positive valued C is handled. + if (((Pred == ICmpInst::ICMP_SGT || Pred == ICmpInst::ICMP_SLT) && + !C.isZero()) || + ((Pred == ICmpInst::ICMP_EQ || Pred == ICmpInst::ICMP_NE) && + !C.isStrictlyPositive())) + return nullptr; + + // Mask off the sign bit and the modulo bits (low-bits). + Type *Ty = SRem->getType(); + APInt SignMask = APInt::getSignMask(Ty->getScalarSizeInBits()); + Constant *MaskC = ConstantInt::get(Ty, SignMask | (*DivisorC - 1)); + Value *And = Builder.CreateAnd(SRem->getOperand(0), MaskC); + + if (Pred == ICmpInst::ICMP_EQ || Pred == ICmpInst::ICMP_NE) + return new ICmpInst(Pred, And, ConstantInt::get(Ty, C)); + + // For 'is positive?' check that the sign-bit is clear and at least 1 masked + // bit is set. Example: + // (i8 X % 32) s> 0 --> (X & 159) s> 0 + if (Pred == ICmpInst::ICMP_SGT) + return new ICmpInst(ICmpInst::ICMP_SGT, And, ConstantInt::getNullValue(Ty)); + + // For 'is negative?' check that the sign-bit is set and at least 1 masked + // bit is set. Example: + // (i16 X % 4) s< 0 --> (X & 32771) u> 32768 + return new ICmpInst(ICmpInst::ICMP_UGT, And, ConstantInt::get(Ty, SignMask)); +} + +/// Fold icmp (udiv X, Y), C. 
+Instruction *InstCombinerImpl::foldICmpUDivConstant(ICmpInst &Cmp, + BinaryOperator *UDiv, + const APInt &C) { + ICmpInst::Predicate Pred = Cmp.getPredicate(); + Value *X = UDiv->getOperand(0); + Value *Y = UDiv->getOperand(1); + Type *Ty = UDiv->getType(); + + const APInt *C2; + if (!match(X, m_APInt(C2))) + return nullptr; + + assert(*C2 != 0 && "udiv 0, X should have been simplified already."); + + // (icmp ugt (udiv C2, Y), C) -> (icmp ule Y, C2/(C+1)) + if (Pred == ICmpInst::ICMP_UGT) { + assert(!C.isMaxValue() && + "icmp ugt X, UINT_MAX should have been simplified already."); + return new ICmpInst(ICmpInst::ICMP_ULE, Y, + ConstantInt::get(Ty, C2->udiv(C + 1))); + } + + // (icmp ult (udiv C2, Y), C) -> (icmp ugt Y, C2/C) + if (Pred == ICmpInst::ICMP_ULT) { + assert(C != 0 && "icmp ult X, 0 should have been simplified already."); + return new ICmpInst(ICmpInst::ICMP_UGT, Y, + ConstantInt::get(Ty, C2->udiv(C))); + } + + return nullptr; +} + +/// Fold icmp ({su}div X, Y), C. +Instruction *InstCombinerImpl::foldICmpDivConstant(ICmpInst &Cmp, + BinaryOperator *Div, + const APInt &C) { + ICmpInst::Predicate Pred = Cmp.getPredicate(); + Value *X = Div->getOperand(0); + Value *Y = Div->getOperand(1); + Type *Ty = Div->getType(); + bool DivIsSigned = Div->getOpcode() == Instruction::SDiv; + + // If unsigned division and the compare constant is bigger than + // UMAX/2 (negative), there's only one pair of values that satisfies an + // equality check, so eliminate the division: + // (X u/ Y) == C --> (X == C) && (Y == 1) + // (X u/ Y) != C --> (X != C) || (Y != 1) + // Similarly, if signed division and the compare constant is exactly SMIN: + // (X s/ Y) == SMIN --> (X == SMIN) && (Y == 1) + // (X s/ Y) != SMIN --> (X != SMIN) || (Y != 1) + if (Cmp.isEquality() && Div->hasOneUse() && C.isSignBitSet() && + (!DivIsSigned || C.isMinSignedValue())) { + Value *XBig = Builder.CreateICmp(Pred, X, ConstantInt::get(Ty, C)); + Value *YOne = Builder.CreateICmp(Pred, Y, 
ConstantInt::get(Ty, 1)); + auto Logic = Pred == ICmpInst::ICMP_EQ ? Instruction::And : Instruction::Or; + return BinaryOperator::Create(Logic, XBig, YOne); + } + + // Fold: icmp pred ([us]div X, C2), C -> range test + // Fold this div into the comparison, producing a range check. + // Determine, based on the divide type, what the range is being + // checked. If there is an overflow on the low or high side, remember + // it, otherwise compute the range [low, hi) bounding the new value. + // See: InsertRangeTest above for the kinds of replacements possible. + const APInt *C2; + if (!match(Y, m_APInt(C2))) + return nullptr; + + // FIXME: If the operand types don't match the type of the divide + // then don't attempt this transform. The code below doesn't have the + // logic to deal with a signed divide and an unsigned compare (and + // vice versa). This is because (x /s C2) <s C produces different + // results than (x /s C2) <u C or (x /u C2) <s C or even + // (x /u C2) <u C. Simply casting the operands and result won't + // work. :( The if statement below tests that condition and bails + // if it finds it. + if (!Cmp.isEquality() && DivIsSigned != Cmp.isSigned()) + return nullptr; + + // The ProdOV computation fails on divide by 0 and divide by -1. Cases with + // INT_MIN will also fail if the divisor is 1. Although folds of all these + // division-by-constant cases should be present, we can not assert that they + // have happened before we reach this icmp instruction. + if (C2->isZero() || C2->isOne() || (DivIsSigned && C2->isAllOnes())) + return nullptr; + + // Compute Prod = C * C2. We are essentially solving an equation of + // form X / C2 = C. We solve for X by multiplying C2 and C. + // By solving for X, we can turn this into a range check instead of computing + // a divide. + APInt Prod = C * *C2; + + // Determine if the product overflows by seeing if the product is not equal to + // the divide. 
Make sure we do the same kind of divide as in the LHS + // instruction that we're folding. + bool ProdOV = (DivIsSigned ? Prod.sdiv(*C2) : Prod.udiv(*C2)) != C; + + // If the division is known to be exact, then there is no remainder from the + // divide, so the covered range size is unit, otherwise it is the divisor. + APInt RangeSize = Div->isExact() ? APInt(C2->getBitWidth(), 1) : *C2; + + // Figure out the interval that is being checked. For example, a comparison + // like "X /u 5 == 0" is really checking that X is in the interval [0, 5). + // Compute this interval based on the constants involved and the signedness of + // the compare/divide. This computes a half-open interval, keeping track of + // whether either value in the interval overflows. After analysis each + // overflow variable is set to 0 if it's corresponding bound variable is valid + // -1 if overflowed off the bottom end, or +1 if overflowed off the top end. + int LoOverflow = 0, HiOverflow = 0; + APInt LoBound, HiBound; + + if (!DivIsSigned) { // udiv + // e.g. X/5 op 3 --> [15, 20) + LoBound = Prod; + HiOverflow = LoOverflow = ProdOV; + if (!HiOverflow) { + // If this is not an exact divide, then many values in the range collapse + // to the same result value. + HiOverflow = addWithOverflow(HiBound, LoBound, RangeSize, false); + } + } else if (C2->isStrictlyPositive()) { // Divisor is > 0. + if (C.isZero()) { // (X / pos) op 0 + // Can't overflow. e.g. X/2 op 0 --> [-1, 2) + LoBound = -(RangeSize - 1); + HiBound = RangeSize; + } else if (C.isStrictlyPositive()) { // (X / pos) op pos + LoBound = Prod; // e.g. X/5 op 3 --> [15, 20) + HiOverflow = LoOverflow = ProdOV; + if (!HiOverflow) + HiOverflow = addWithOverflow(HiBound, Prod, RangeSize, true); + } else { // (X / pos) op neg + // e.g. X/5 op -3 --> [-15-4, -15+1) --> [-19, -14) + HiBound = Prod + 1; + LoOverflow = HiOverflow = ProdOV ? 
-1 : 0; + if (!LoOverflow) { + APInt DivNeg = -RangeSize; + LoOverflow = addWithOverflow(LoBound, HiBound, DivNeg, true) ? -1 : 0; + } + } + } else if (C2->isNegative()) { // Divisor is < 0. + if (Div->isExact()) + RangeSize.negate(); + if (C.isZero()) { // (X / neg) op 0 + // e.g. X/-5 op 0 --> [-4, 5) + LoBound = RangeSize + 1; + HiBound = -RangeSize; + if (HiBound == *C2) { // -INTMIN = INTMIN + HiOverflow = 1; // [INTMIN+1, overflow) + HiBound = APInt(); // e.g. X/INTMIN = 0 --> X > INTMIN + } + } else if (C.isStrictlyPositive()) { // (X / neg) op pos + // e.g. X/-5 op 3 --> [-19, -14) + HiBound = Prod + 1; + HiOverflow = LoOverflow = ProdOV ? -1 : 0; + if (!LoOverflow) + LoOverflow = + addWithOverflow(LoBound, HiBound, RangeSize, true) ? -1 : 0; + } else { // (X / neg) op neg + LoBound = Prod; // e.g. X/-5 op -3 --> [15, 20) + LoOverflow = HiOverflow = ProdOV; + if (!HiOverflow) + HiOverflow = subWithOverflow(HiBound, Prod, RangeSize, true); + } + + // Dividing by a negative swaps the condition. LT <-> GT + Pred = ICmpInst::getSwappedPredicate(Pred); + } + + switch (Pred) { + default: + llvm_unreachable("Unhandled icmp predicate!"); + case ICmpInst::ICMP_EQ: + if (LoOverflow && HiOverflow) + return replaceInstUsesWith(Cmp, Builder.getFalse()); + if (HiOverflow) + return new ICmpInst(DivIsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE, + X, ConstantInt::get(Ty, LoBound)); + if (LoOverflow) + return new ICmpInst(DivIsSigned ? ICmpInst::ICMP_SLT : ICmpInst::ICMP_ULT, + X, ConstantInt::get(Ty, HiBound)); + return replaceInstUsesWith( + Cmp, insertRangeTest(X, LoBound, HiBound, DivIsSigned, true)); + case ICmpInst::ICMP_NE: + if (LoOverflow && HiOverflow) + return replaceInstUsesWith(Cmp, Builder.getTrue()); + if (HiOverflow) + return new ICmpInst(DivIsSigned ? ICmpInst::ICMP_SLT : ICmpInst::ICMP_ULT, + X, ConstantInt::get(Ty, LoBound)); + if (LoOverflow) + return new ICmpInst(DivIsSigned ? 
ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE, + X, ConstantInt::get(Ty, HiBound)); + return replaceInstUsesWith( + Cmp, insertRangeTest(X, LoBound, HiBound, DivIsSigned, false)); + case ICmpInst::ICMP_ULT: + case ICmpInst::ICMP_SLT: + if (LoOverflow == +1) // Low bound is greater than input range. + return replaceInstUsesWith(Cmp, Builder.getTrue()); + if (LoOverflow == -1) // Low bound is less than input range. + return replaceInstUsesWith(Cmp, Builder.getFalse()); + return new ICmpInst(Pred, X, ConstantInt::get(Ty, LoBound)); + case ICmpInst::ICMP_UGT: + case ICmpInst::ICMP_SGT: + if (HiOverflow == +1) // High bound greater than input range. + return replaceInstUsesWith(Cmp, Builder.getFalse()); + if (HiOverflow == -1) // High bound less than input range. + return replaceInstUsesWith(Cmp, Builder.getTrue()); + if (Pred == ICmpInst::ICMP_UGT) + return new ICmpInst(ICmpInst::ICMP_UGE, X, ConstantInt::get(Ty, HiBound)); + return new ICmpInst(ICmpInst::ICMP_SGE, X, ConstantInt::get(Ty, HiBound)); + } + + return nullptr; +} + +/// Fold icmp (sub X, Y), C. 
+Instruction *InstCombinerImpl::foldICmpSubConstant(ICmpInst &Cmp, + BinaryOperator *Sub, + const APInt &C) { + Value *X = Sub->getOperand(0), *Y = Sub->getOperand(1); + ICmpInst::Predicate Pred = Cmp.getPredicate(); + Type *Ty = Sub->getType(); + + // (SubC - Y) == C) --> Y == (SubC - C) + // (SubC - Y) != C) --> Y != (SubC - C) + Constant *SubC; + if (Cmp.isEquality() && match(X, m_ImmConstant(SubC))) { + return new ICmpInst(Pred, Y, + ConstantExpr::getSub(SubC, ConstantInt::get(Ty, C))); + } + + // (icmp P (sub nuw|nsw C2, Y), C) -> (icmp swap(P) Y, C2-C) + const APInt *C2; + APInt SubResult; + ICmpInst::Predicate SwappedPred = Cmp.getSwappedPredicate(); + bool HasNSW = Sub->hasNoSignedWrap(); + bool HasNUW = Sub->hasNoUnsignedWrap(); + if (match(X, m_APInt(C2)) && + ((Cmp.isUnsigned() && HasNUW) || (Cmp.isSigned() && HasNSW)) && + !subWithOverflow(SubResult, *C2, C, Cmp.isSigned())) + return new ICmpInst(SwappedPred, Y, ConstantInt::get(Ty, SubResult)); + + // X - Y == 0 --> X == Y. + // X - Y != 0 --> X != Y. + // TODO: We allow this with multiple uses as long as the other uses are not + // in phis. The phi use check is guarding against a codegen regression + // for a loop test. If the backend could undo this (and possibly + // subsequent transforms), we would not need this hack. + if (Cmp.isEquality() && C.isZero() && + none_of((Sub->users()), [](const User *U) { return isa<PHINode>(U); })) + return new ICmpInst(Pred, X, Y); + + // The following transforms are only worth it if the only user of the subtract + // is the icmp. + // TODO: This is an artificial restriction for all of the transforms below + // that only need a single replacement icmp. Can these use the phi test + // like the transform above here? 
+ if (!Sub->hasOneUse()) + return nullptr; + + if (Sub->hasNoSignedWrap()) { + // (icmp sgt (sub nsw X, Y), -1) -> (icmp sge X, Y) + if (Pred == ICmpInst::ICMP_SGT && C.isAllOnes()) + return new ICmpInst(ICmpInst::ICMP_SGE, X, Y); + + // (icmp sgt (sub nsw X, Y), 0) -> (icmp sgt X, Y) + if (Pred == ICmpInst::ICMP_SGT && C.isZero()) + return new ICmpInst(ICmpInst::ICMP_SGT, X, Y); + + // (icmp slt (sub nsw X, Y), 0) -> (icmp slt X, Y) + if (Pred == ICmpInst::ICMP_SLT && C.isZero()) + return new ICmpInst(ICmpInst::ICMP_SLT, X, Y); + + // (icmp slt (sub nsw X, Y), 1) -> (icmp sle X, Y) + if (Pred == ICmpInst::ICMP_SLT && C.isOne()) + return new ICmpInst(ICmpInst::ICMP_SLE, X, Y); + } + + if (!match(X, m_APInt(C2))) + return nullptr; + + // C2 - Y <u C -> (Y | (C - 1)) == C2 + // iff (C2 & (C - 1)) == C - 1 and C is a power of 2 + if (Pred == ICmpInst::ICMP_ULT && C.isPowerOf2() && + (*C2 & (C - 1)) == (C - 1)) + return new ICmpInst(ICmpInst::ICMP_EQ, Builder.CreateOr(Y, C - 1), X); + + // C2 - Y >u C -> (Y | C) != C2 + // iff C2 & C == C and C + 1 is a power of 2 + if (Pred == ICmpInst::ICMP_UGT && (C + 1).isPowerOf2() && (*C2 & C) == C) + return new ICmpInst(ICmpInst::ICMP_NE, Builder.CreateOr(Y, C), X); + + // We have handled special cases that reduce. + // Canonicalize any remaining sub to add as: + // (C2 - Y) > C --> (Y + ~C2) < ~C + Value *Add = Builder.CreateAdd(Y, ConstantInt::get(Ty, ~(*C2)), "notsub", + HasNUW, HasNSW); + return new ICmpInst(SwappedPred, Add, ConstantInt::get(Ty, ~C)); +} + +/// Fold icmp (add X, Y), C. +Instruction *InstCombinerImpl::foldICmpAddConstant(ICmpInst &Cmp, + BinaryOperator *Add, + const APInt &C) { + Value *Y = Add->getOperand(1); + const APInt *C2; + if (Cmp.isEquality() || !match(Y, m_APInt(C2))) + return nullptr; + + // Fold icmp pred (add X, C2), C. 
+ Value *X = Add->getOperand(0); + Type *Ty = Add->getType(); + const CmpInst::Predicate Pred = Cmp.getPredicate(); + + // If the add does not wrap, we can always adjust the compare by subtracting + // the constants. Equality comparisons are handled elsewhere. SGE/SLE/UGE/ULE + // are canonicalized to SGT/SLT/UGT/ULT. + if ((Add->hasNoSignedWrap() && + (Pred == ICmpInst::ICMP_SGT || Pred == ICmpInst::ICMP_SLT)) || + (Add->hasNoUnsignedWrap() && + (Pred == ICmpInst::ICMP_UGT || Pred == ICmpInst::ICMP_ULT))) { + bool Overflow; + APInt NewC = + Cmp.isSigned() ? C.ssub_ov(*C2, Overflow) : C.usub_ov(*C2, Overflow); + // If there is overflow, the result must be true or false. + // TODO: Can we assert there is no overflow because InstSimplify always + // handles those cases? + if (!Overflow) + // icmp Pred (add nsw X, C2), C --> icmp Pred X, (C - C2) + return new ICmpInst(Pred, X, ConstantInt::get(Ty, NewC)); + } + + auto CR = ConstantRange::makeExactICmpRegion(Pred, C).subtract(*C2); + const APInt &Upper = CR.getUpper(); + const APInt &Lower = CR.getLower(); + if (Cmp.isSigned()) { + if (Lower.isSignMask()) + return new ICmpInst(ICmpInst::ICMP_SLT, X, ConstantInt::get(Ty, Upper)); + if (Upper.isSignMask()) + return new ICmpInst(ICmpInst::ICMP_SGE, X, ConstantInt::get(Ty, Lower)); + } else { + if (Lower.isMinValue()) + return new ICmpInst(ICmpInst::ICMP_ULT, X, ConstantInt::get(Ty, Upper)); + if (Upper.isMinValue()) + return new ICmpInst(ICmpInst::ICMP_UGE, X, ConstantInt::get(Ty, Lower)); + } + + // This set of folds is intentionally placed after folds that use no-wrapping + // flags because those folds are likely better for later analysis/codegen. 
+ const APInt SMax = APInt::getSignedMaxValue(Ty->getScalarSizeInBits()); + const APInt SMin = APInt::getSignedMinValue(Ty->getScalarSizeInBits()); + + // Fold compare with offset to opposite sign compare if it eliminates offset: + // (X + C2) >u C --> X <s -C2 (if C == C2 + SMAX) + if (Pred == CmpInst::ICMP_UGT && C == *C2 + SMax) + return new ICmpInst(ICmpInst::ICMP_SLT, X, ConstantInt::get(Ty, -(*C2))); + + // (X + C2) <u C --> X >s ~C2 (if C == C2 + SMIN) + if (Pred == CmpInst::ICMP_ULT && C == *C2 + SMin) + return new ICmpInst(ICmpInst::ICMP_SGT, X, ConstantInt::get(Ty, ~(*C2))); + + // (X + C2) >s C --> X <u (SMAX - C) (if C == C2 - 1) + if (Pred == CmpInst::ICMP_SGT && C == *C2 - 1) + return new ICmpInst(ICmpInst::ICMP_ULT, X, ConstantInt::get(Ty, SMax - C)); + + // (X + C2) <s C --> X >u (C ^ SMAX) (if C == C2) + if (Pred == CmpInst::ICMP_SLT && C == *C2) + return new ICmpInst(ICmpInst::ICMP_UGT, X, ConstantInt::get(Ty, C ^ SMax)); + + if (!Add->hasOneUse()) + return nullptr; + + // X+C <u C2 -> (X & -C2) == C + // iff C & (C2-1) == 0 + // C2 is a power of 2 + if (Pred == ICmpInst::ICMP_ULT && C.isPowerOf2() && (*C2 & (C - 1)) == 0) + return new ICmpInst(ICmpInst::ICMP_EQ, Builder.CreateAnd(X, -C), + ConstantExpr::getNeg(cast<Constant>(Y))); + + // X+C >u C2 -> (X & ~C2) != C + // iff C & C2 == 0 + // C2+1 is a power of 2 + if (Pred == ICmpInst::ICMP_UGT && (C + 1).isPowerOf2() && (*C2 & C) == 0) + return new ICmpInst(ICmpInst::ICMP_NE, Builder.CreateAnd(X, ~C), + ConstantExpr::getNeg(cast<Constant>(Y))); + + // The range test idiom can use either ult or ugt. Arbitrarily canonicalize + // to the ult form. 
+ // X+C2 >u C -> X+(C2-C-1) <u ~C + if (Pred == ICmpInst::ICMP_UGT) + return new ICmpInst(ICmpInst::ICMP_ULT, + Builder.CreateAdd(X, ConstantInt::get(Ty, *C2 - C - 1)), + ConstantInt::get(Ty, ~C)); + + return nullptr; +} + +bool InstCombinerImpl::matchThreeWayIntCompare(SelectInst *SI, Value *&LHS, + Value *&RHS, ConstantInt *&Less, + ConstantInt *&Equal, + ConstantInt *&Greater) { + // TODO: Generalize this to work with other comparison idioms or ensure + // they get canonicalized into this form. + + // select i1 (a == b), + // i32 Equal, + // i32 (select i1 (a < b), i32 Less, i32 Greater) + // where Equal, Less and Greater are placeholders for any three constants. + ICmpInst::Predicate PredA; + if (!match(SI->getCondition(), m_ICmp(PredA, m_Value(LHS), m_Value(RHS))) || + !ICmpInst::isEquality(PredA)) + return false; + Value *EqualVal = SI->getTrueValue(); + Value *UnequalVal = SI->getFalseValue(); + // We still can get non-canonical predicate here, so canonicalize. + if (PredA == ICmpInst::ICMP_NE) + std::swap(EqualVal, UnequalVal); + if (!match(EqualVal, m_ConstantInt(Equal))) + return false; + ICmpInst::Predicate PredB; + Value *LHS2, *RHS2; + if (!match(UnequalVal, m_Select(m_ICmp(PredB, m_Value(LHS2), m_Value(RHS2)), + m_ConstantInt(Less), m_ConstantInt(Greater)))) + return false; + // We can get predicate mismatch here, so canonicalize if possible: + // First, ensure that 'LHS' match. + if (LHS2 != LHS) { + // x sgt y <--> y slt x + std::swap(LHS2, RHS2); + PredB = ICmpInst::getSwappedPredicate(PredB); + } + if (LHS2 != LHS) + return false; + // We also need to canonicalize 'RHS'. 
+ if (PredB == ICmpInst::ICMP_SGT && isa<Constant>(RHS2)) { + // x sgt C-1 <--> x sge C <--> not(x slt C) + auto FlippedStrictness = + InstCombiner::getFlippedStrictnessPredicateAndConstant( + PredB, cast<Constant>(RHS2)); + if (!FlippedStrictness) + return false; + assert(FlippedStrictness->first == ICmpInst::ICMP_SGE && + "basic correctness failure"); + RHS2 = FlippedStrictness->second; + // And kind-of perform the result swap. + std::swap(Less, Greater); + PredB = ICmpInst::ICMP_SLT; + } + return PredB == ICmpInst::ICMP_SLT && RHS == RHS2; +} + +Instruction *InstCombinerImpl::foldICmpSelectConstant(ICmpInst &Cmp, + SelectInst *Select, + ConstantInt *C) { + + assert(C && "Cmp RHS should be a constant int!"); + // If we're testing a constant value against the result of a three way + // comparison, the result can be expressed directly in terms of the + // original values being compared. Note: We could possibly be more + // aggressive here and remove the hasOneUse test. The original select is + // really likely to simplify or sink when we remove a test of the result. + Value *OrigLHS, *OrigRHS; + ConstantInt *C1LessThan, *C2Equal, *C3GreaterThan; + if (Cmp.hasOneUse() && + matchThreeWayIntCompare(Select, OrigLHS, OrigRHS, C1LessThan, C2Equal, + C3GreaterThan)) { + assert(C1LessThan && C2Equal && C3GreaterThan); + + bool TrueWhenLessThan = + ConstantExpr::getCompare(Cmp.getPredicate(), C1LessThan, C) + ->isAllOnesValue(); + bool TrueWhenEqual = + ConstantExpr::getCompare(Cmp.getPredicate(), C2Equal, C) + ->isAllOnesValue(); + bool TrueWhenGreaterThan = + ConstantExpr::getCompare(Cmp.getPredicate(), C3GreaterThan, C) + ->isAllOnesValue(); + + // This generates the new instruction that will replace the original Cmp + // Instruction. Instead of enumerating the various combinations when + // TrueWhenLessThan, TrueWhenEqual and TrueWhenGreaterThan are true versus + // false, we rely on chaining of ORs and future passes of InstCombine to + // simplify the OR further (i.e. 
a s< b || a == b becomes a s<= b). + + // When none of the three constants satisfy the predicate for the RHS (C), + // the entire original Cmp can be simplified to a false. + Value *Cond = Builder.getFalse(); + if (TrueWhenLessThan) + Cond = Builder.CreateOr(Cond, Builder.CreateICmp(ICmpInst::ICMP_SLT, + OrigLHS, OrigRHS)); + if (TrueWhenEqual) + Cond = Builder.CreateOr(Cond, Builder.CreateICmp(ICmpInst::ICMP_EQ, + OrigLHS, OrigRHS)); + if (TrueWhenGreaterThan) + Cond = Builder.CreateOr(Cond, Builder.CreateICmp(ICmpInst::ICMP_SGT, + OrigLHS, OrigRHS)); + + return replaceInstUsesWith(Cmp, Cond); + } + return nullptr; +} + +Instruction *InstCombinerImpl::foldICmpBitCast(ICmpInst &Cmp) { + auto *Bitcast = dyn_cast<BitCastInst>(Cmp.getOperand(0)); + if (!Bitcast) + return nullptr; + + ICmpInst::Predicate Pred = Cmp.getPredicate(); + Value *Op1 = Cmp.getOperand(1); + Value *BCSrcOp = Bitcast->getOperand(0); + Type *SrcType = Bitcast->getSrcTy(); + Type *DstType = Bitcast->getType(); + + // Make sure the bitcast doesn't change between scalar and vector and + // doesn't change the number of vector elements. + if (SrcType->isVectorTy() == DstType->isVectorTy() && + SrcType->getScalarSizeInBits() == DstType->getScalarSizeInBits()) { + // Zero-equality and sign-bit checks are preserved through sitofp + bitcast. 
+ Value *X; + if (match(BCSrcOp, m_SIToFP(m_Value(X)))) { + // icmp eq (bitcast (sitofp X)), 0 --> icmp eq X, 0 + // icmp ne (bitcast (sitofp X)), 0 --> icmp ne X, 0 + // icmp slt (bitcast (sitofp X)), 0 --> icmp slt X, 0 + // icmp sgt (bitcast (sitofp X)), 0 --> icmp sgt X, 0 + if ((Pred == ICmpInst::ICMP_EQ || Pred == ICmpInst::ICMP_SLT || + Pred == ICmpInst::ICMP_NE || Pred == ICmpInst::ICMP_SGT) && + match(Op1, m_Zero())) + return new ICmpInst(Pred, X, ConstantInt::getNullValue(X->getType())); + + // icmp slt (bitcast (sitofp X)), 1 --> icmp slt X, 1 + if (Pred == ICmpInst::ICMP_SLT && match(Op1, m_One())) + return new ICmpInst(Pred, X, ConstantInt::get(X->getType(), 1)); + + // icmp sgt (bitcast (sitofp X)), -1 --> icmp sgt X, -1 + if (Pred == ICmpInst::ICMP_SGT && match(Op1, m_AllOnes())) + return new ICmpInst(Pred, X, + ConstantInt::getAllOnesValue(X->getType())); + } + + // Zero-equality checks are preserved through unsigned floating-point casts: + // icmp eq (bitcast (uitofp X)), 0 --> icmp eq X, 0 + // icmp ne (bitcast (uitofp X)), 0 --> icmp ne X, 0 + if (match(BCSrcOp, m_UIToFP(m_Value(X)))) + if (Cmp.isEquality() && match(Op1, m_Zero())) + return new ICmpInst(Pred, X, ConstantInt::getNullValue(X->getType())); + + // If this is a sign-bit test of a bitcast of a casted FP value, eliminate + // the FP extend/truncate because that cast does not change the sign-bit. + // This is true for all standard IEEE-754 types and the X86 80-bit type. + // The sign-bit is always the most significant bit in those types. 
+ const APInt *C; + bool TrueIfSigned; + if (match(Op1, m_APInt(C)) && Bitcast->hasOneUse() && + InstCombiner::isSignBitCheck(Pred, *C, TrueIfSigned)) { + if (match(BCSrcOp, m_FPExt(m_Value(X))) || + match(BCSrcOp, m_FPTrunc(m_Value(X)))) { + // (bitcast (fpext/fptrunc X)) to iX) < 0 --> (bitcast X to iY) < 0 + // (bitcast (fpext/fptrunc X)) to iX) > -1 --> (bitcast X to iY) > -1 + Type *XType = X->getType(); + + // We can't currently handle Power style floating point operations here. + if (!(XType->isPPC_FP128Ty() || SrcType->isPPC_FP128Ty())) { + Type *NewType = Builder.getIntNTy(XType->getScalarSizeInBits()); + if (auto *XVTy = dyn_cast<VectorType>(XType)) + NewType = VectorType::get(NewType, XVTy->getElementCount()); + Value *NewBitcast = Builder.CreateBitCast(X, NewType); + if (TrueIfSigned) + return new ICmpInst(ICmpInst::ICMP_SLT, NewBitcast, + ConstantInt::getNullValue(NewType)); + else + return new ICmpInst(ICmpInst::ICMP_SGT, NewBitcast, + ConstantInt::getAllOnesValue(NewType)); + } + } + } + } + + // Test to see if the operands of the icmp are casted versions of other + // values. If the ptr->ptr cast can be stripped off both arguments, do so. + if (DstType->isPointerTy() && (isa<Constant>(Op1) || isa<BitCastInst>(Op1))) { + // If operand #1 is a bitcast instruction, it must also be a ptr->ptr cast + // so eliminate it as well. + if (auto *BC2 = dyn_cast<BitCastInst>(Op1)) + Op1 = BC2->getOperand(0); + + Op1 = Builder.CreateBitCast(Op1, SrcType); + return new ICmpInst(Pred, BCSrcOp, Op1); + } + + const APInt *C; + if (!match(Cmp.getOperand(1), m_APInt(C)) || !DstType->isIntegerTy() || + !SrcType->isIntOrIntVectorTy()) + return nullptr; + + // If this is checking if all elements of a vector compare are set or not, + // invert the casted vector equality compare and test if all compare + // elements are clear or not. Compare against zero is generally easier for + // analysis and codegen. 
+ // icmp eq/ne (bitcast (not X) to iN), -1 --> icmp eq/ne (bitcast X to iN), 0 + // Example: are all elements equal? --> are zero elements not equal? + // TODO: Try harder to reduce compare of 2 freely invertible operands? + if (Cmp.isEquality() && C->isAllOnes() && Bitcast->hasOneUse() && + isFreeToInvert(BCSrcOp, BCSrcOp->hasOneUse())) { + Value *Cast = Builder.CreateBitCast(Builder.CreateNot(BCSrcOp), DstType); + return new ICmpInst(Pred, Cast, ConstantInt::getNullValue(DstType)); + } + + // If this is checking if all elements of an extended vector are clear or not, + // compare in a narrow type to eliminate the extend: + // icmp eq/ne (bitcast (ext X) to iN), 0 --> icmp eq/ne (bitcast X to iM), 0 + Value *X; + if (Cmp.isEquality() && C->isZero() && Bitcast->hasOneUse() && + match(BCSrcOp, m_ZExtOrSExt(m_Value(X)))) { + if (auto *VecTy = dyn_cast<FixedVectorType>(X->getType())) { + Type *NewType = Builder.getIntNTy(VecTy->getPrimitiveSizeInBits()); + Value *NewCast = Builder.CreateBitCast(X, NewType); + return new ICmpInst(Pred, NewCast, ConstantInt::getNullValue(NewType)); + } + } + + // Folding: icmp <pred> iN X, C + // where X = bitcast <M x iK> (shufflevector <M x iK> %vec, undef, SC)) to iN + // and C is a splat of a K-bit pattern + // and SC is a constant vector = <C', C', C', ..., C'> + // Into: + // %E = extractelement <M x iK> %vec, i32 C' + // icmp <pred> iK %E, trunc(C) + Value *Vec; + ArrayRef<int> Mask; + if (match(BCSrcOp, m_Shuffle(m_Value(Vec), m_Undef(), m_Mask(Mask)))) { + // Check whether every element of Mask is the same constant + if (is_splat(Mask)) { + auto *VecTy = cast<VectorType>(SrcType); + auto *EltTy = cast<IntegerType>(VecTy->getElementType()); + if (C->isSplat(EltTy->getBitWidth())) { + // Fold the icmp based on the value of C + // If C is M copies of an iK sized bit pattern, + // then: + // => %E = extractelement <N x iK> %vec, i32 Elem + // icmp <pred> iK %SplatVal, <pattern> + Value *Elem = Builder.getInt32(Mask[0]); + Value 
*Extract = Builder.CreateExtractElement(Vec, Elem); + Value *NewC = ConstantInt::get(EltTy, C->trunc(EltTy->getBitWidth())); + return new ICmpInst(Pred, Extract, NewC); + } + } + } + return nullptr; +} + +/// Try to fold integer comparisons with a constant operand: icmp Pred X, C +/// where X is some kind of instruction. +Instruction *InstCombinerImpl::foldICmpInstWithConstant(ICmpInst &Cmp) { + const APInt *C; + + if (match(Cmp.getOperand(1), m_APInt(C))) { + if (auto *BO = dyn_cast<BinaryOperator>(Cmp.getOperand(0))) + if (Instruction *I = foldICmpBinOpWithConstant(Cmp, BO, *C)) + return I; + + if (auto *SI = dyn_cast<SelectInst>(Cmp.getOperand(0))) + // For now, we only support constant integers while folding the + // ICMP(SELECT)) pattern. We can extend this to support vector of integers + // similar to the cases handled by binary ops above. + if (auto *ConstRHS = dyn_cast<ConstantInt>(Cmp.getOperand(1))) + if (Instruction *I = foldICmpSelectConstant(Cmp, SI, ConstRHS)) + return I; + + if (auto *TI = dyn_cast<TruncInst>(Cmp.getOperand(0))) + if (Instruction *I = foldICmpTruncConstant(Cmp, TI, *C)) + return I; + + if (auto *II = dyn_cast<IntrinsicInst>(Cmp.getOperand(0))) + if (Instruction *I = foldICmpIntrinsicWithConstant(Cmp, II, *C)) + return I; + } + + if (match(Cmp.getOperand(1), m_APIntAllowUndef(C))) + return foldICmpInstWithConstantAllowUndef(Cmp, *C); + + return nullptr; +} + +/// Fold an icmp equality instruction with binary operator LHS and constant RHS: +/// icmp eq/ne BO, C. +Instruction *InstCombinerImpl::foldICmpBinOpEqualityWithConstant( + ICmpInst &Cmp, BinaryOperator *BO, const APInt &C) { + // TODO: Some of these folds could work with arbitrary constants, but this + // function is limited to scalar and vector splat constants. 
+ if (!Cmp.isEquality()) + return nullptr; + + ICmpInst::Predicate Pred = Cmp.getPredicate(); + bool isICMP_NE = Pred == ICmpInst::ICMP_NE; + Constant *RHS = cast<Constant>(Cmp.getOperand(1)); + Value *BOp0 = BO->getOperand(0), *BOp1 = BO->getOperand(1); + + switch (BO->getOpcode()) { + case Instruction::SRem: + // If we have a signed (X % (2^c)) == 0, turn it into an unsigned one. + if (C.isZero() && BO->hasOneUse()) { + const APInt *BOC; + if (match(BOp1, m_APInt(BOC)) && BOC->sgt(1) && BOC->isPowerOf2()) { + Value *NewRem = Builder.CreateURem(BOp0, BOp1, BO->getName()); + return new ICmpInst(Pred, NewRem, + Constant::getNullValue(BO->getType())); + } + } + break; + case Instruction::Add: { + // Replace ((add A, B) != C) with (A != C-B) if B & C are constants. + if (Constant *BOC = dyn_cast<Constant>(BOp1)) { + if (BO->hasOneUse()) + return new ICmpInst(Pred, BOp0, ConstantExpr::getSub(RHS, BOC)); + } else if (C.isZero()) { + // Replace ((add A, B) != 0) with (A != -B) if A or B is + // efficiently invertible, or if the add has just this one use. + if (Value *NegVal = dyn_castNegVal(BOp1)) + return new ICmpInst(Pred, BOp0, NegVal); + if (Value *NegVal = dyn_castNegVal(BOp0)) + return new ICmpInst(Pred, NegVal, BOp1); + if (BO->hasOneUse()) { + Value *Neg = Builder.CreateNeg(BOp1); + Neg->takeName(BO); + return new ICmpInst(Pred, BOp0, Neg); + } + } + break; + } + case Instruction::Xor: + if (BO->hasOneUse()) { + if (Constant *BOC = dyn_cast<Constant>(BOp1)) { + // For the xor case, we can xor two constants together, eliminating + // the explicit xor. + return new ICmpInst(Pred, BOp0, ConstantExpr::getXor(RHS, BOC)); + } else if (C.isZero()) { + // Replace ((xor A, B) != 0) with (A != B) + return new ICmpInst(Pred, BOp0, BOp1); + } + } + break; + case Instruction::Or: { + const APInt *BOC; + if (match(BOp1, m_APInt(BOC)) && BO->hasOneUse() && RHS->isAllOnesValue()) { + // Comparing if all bits outside of a constant mask are set? 
+ // Replace (X | C) == -1 with (X & ~C) == ~C. + // This removes the -1 constant. + Constant *NotBOC = ConstantExpr::getNot(cast<Constant>(BOp1)); + Value *And = Builder.CreateAnd(BOp0, NotBOC); + return new ICmpInst(Pred, And, NotBOC); + } + break; + } + case Instruction::And: { + const APInt *BOC; + if (match(BOp1, m_APInt(BOC))) { + // If we have ((X & C) == C), turn it into ((X & C) != 0). + if (C == *BOC && C.isPowerOf2()) + return new ICmpInst(isICMP_NE ? ICmpInst::ICMP_EQ : ICmpInst::ICMP_NE, + BO, Constant::getNullValue(RHS->getType())); + } + break; + } + case Instruction::UDiv: + if (C.isZero()) { + // (icmp eq/ne (udiv A, B), 0) -> (icmp ugt/ule i32 B, A) + auto NewPred = isICMP_NE ? ICmpInst::ICMP_ULE : ICmpInst::ICMP_UGT; + return new ICmpInst(NewPred, BOp1, BOp0); + } + break; + default: + break; + } + return nullptr; +} + +/// Fold an equality icmp with LLVM intrinsic and constant operand. +Instruction *InstCombinerImpl::foldICmpEqIntrinsicWithConstant( + ICmpInst &Cmp, IntrinsicInst *II, const APInt &C) { + Type *Ty = II->getType(); + unsigned BitWidth = C.getBitWidth(); + const ICmpInst::Predicate Pred = Cmp.getPredicate(); + + switch (II->getIntrinsicID()) { + case Intrinsic::abs: + // abs(A) == 0 -> A == 0 + // abs(A) == INT_MIN -> A == INT_MIN + if (C.isZero() || C.isMinSignedValue()) + return new ICmpInst(Pred, II->getArgOperand(0), ConstantInt::get(Ty, C)); + break; + + case Intrinsic::bswap: + // bswap(A) == C -> A == bswap(C) + return new ICmpInst(Pred, II->getArgOperand(0), + ConstantInt::get(Ty, C.byteSwap())); + + case Intrinsic::ctlz: + case Intrinsic::cttz: { + // ctz(A) == bitwidth(A) -> A == 0 and likewise for != + if (C == BitWidth) + return new ICmpInst(Pred, II->getArgOperand(0), + ConstantInt::getNullValue(Ty)); + + // ctz(A) == C -> A & Mask1 == Mask2, where Mask2 only has bit C set + // and Mask1 has bits 0..C+1 set. Similar for ctl, but for high bits. + // Limit to one use to ensure we don't increase instruction count. 
+ unsigned Num = C.getLimitedValue(BitWidth); + if (Num != BitWidth && II->hasOneUse()) { + bool IsTrailing = II->getIntrinsicID() == Intrinsic::cttz; + APInt Mask1 = IsTrailing ? APInt::getLowBitsSet(BitWidth, Num + 1) + : APInt::getHighBitsSet(BitWidth, Num + 1); + APInt Mask2 = IsTrailing + ? APInt::getOneBitSet(BitWidth, Num) + : APInt::getOneBitSet(BitWidth, BitWidth - Num - 1); + return new ICmpInst(Pred, Builder.CreateAnd(II->getArgOperand(0), Mask1), + ConstantInt::get(Ty, Mask2)); + } + break; + } + + case Intrinsic::ctpop: { + // popcount(A) == 0 -> A == 0 and likewise for != + // popcount(A) == bitwidth(A) -> A == -1 and likewise for != + bool IsZero = C.isZero(); + if (IsZero || C == BitWidth) + return new ICmpInst(Pred, II->getArgOperand(0), + IsZero ? Constant::getNullValue(Ty) + : Constant::getAllOnesValue(Ty)); + + break; + } + + case Intrinsic::fshl: + case Intrinsic::fshr: + if (II->getArgOperand(0) == II->getArgOperand(1)) { + const APInt *RotAmtC; + // ror(X, RotAmtC) == C --> X == rol(C, RotAmtC) + // rol(X, RotAmtC) == C --> X == ror(C, RotAmtC) + if (match(II->getArgOperand(2), m_APInt(RotAmtC))) + return new ICmpInst(Pred, II->getArgOperand(0), + II->getIntrinsicID() == Intrinsic::fshl + ? ConstantInt::get(Ty, C.rotr(*RotAmtC)) + : ConstantInt::get(Ty, C.rotl(*RotAmtC))); + } + break; + + case Intrinsic::uadd_sat: { + // uadd.sat(a, b) == 0 -> (a | b) == 0 + if (C.isZero()) { + Value *Or = Builder.CreateOr(II->getArgOperand(0), II->getArgOperand(1)); + return new ICmpInst(Pred, Or, Constant::getNullValue(Ty)); + } + break; + } + + case Intrinsic::usub_sat: { + // usub.sat(a, b) == 0 -> a <= b + if (C.isZero()) { + ICmpInst::Predicate NewPred = + Pred == ICmpInst::ICMP_EQ ? 
ICmpInst::ICMP_ULE : ICmpInst::ICMP_UGT; + return new ICmpInst(NewPred, II->getArgOperand(0), II->getArgOperand(1)); + } + break; + } + default: + break; + } + + return nullptr; +} + +/// Fold an icmp with LLVM intrinsics +static Instruction *foldICmpIntrinsicWithIntrinsic(ICmpInst &Cmp) { + assert(Cmp.isEquality()); + + ICmpInst::Predicate Pred = Cmp.getPredicate(); + Value *Op0 = Cmp.getOperand(0); + Value *Op1 = Cmp.getOperand(1); + const auto *IIOp0 = dyn_cast<IntrinsicInst>(Op0); + const auto *IIOp1 = dyn_cast<IntrinsicInst>(Op1); + if (!IIOp0 || !IIOp1 || IIOp0->getIntrinsicID() != IIOp1->getIntrinsicID()) + return nullptr; + + switch (IIOp0->getIntrinsicID()) { + case Intrinsic::bswap: + case Intrinsic::bitreverse: + // If both operands are byte-swapped or bit-reversed, just compare the + // original values. + return new ICmpInst(Pred, IIOp0->getOperand(0), IIOp1->getOperand(0)); + case Intrinsic::fshl: + case Intrinsic::fshr: + // If both operands are rotated by same amount, just compare the + // original values. + if (IIOp0->getOperand(0) != IIOp0->getOperand(1)) + break; + if (IIOp1->getOperand(0) != IIOp1->getOperand(1)) + break; + if (IIOp0->getOperand(2) != IIOp1->getOperand(2)) + break; + return new ICmpInst(Pred, IIOp0->getOperand(0), IIOp1->getOperand(0)); + default: + break; + } + + return nullptr; +} + +/// Try to fold integer comparisons with a constant operand: icmp Pred X, C +/// where X is some kind of instruction and C is AllowUndef. +/// TODO: Move more folds which allow undef to this function. +Instruction * +InstCombinerImpl::foldICmpInstWithConstantAllowUndef(ICmpInst &Cmp, + const APInt &C) { + const ICmpInst::Predicate Pred = Cmp.getPredicate(); + if (auto *II = dyn_cast<IntrinsicInst>(Cmp.getOperand(0))) { + switch (II->getIntrinsicID()) { + default: + break; + case Intrinsic::fshl: + case Intrinsic::fshr: + if (Cmp.isEquality() && II->getArgOperand(0) == II->getArgOperand(1)) { + // (rot X, ?) 
== 0/-1 --> X == 0/-1 + if (C.isZero() || C.isAllOnes()) + return new ICmpInst(Pred, II->getArgOperand(0), Cmp.getOperand(1)); + } + break; + } + } + + return nullptr; +} + +/// Fold an icmp with BinaryOp and constant operand: icmp Pred BO, C. +Instruction *InstCombinerImpl::foldICmpBinOpWithConstant(ICmpInst &Cmp, + BinaryOperator *BO, + const APInt &C) { + switch (BO->getOpcode()) { + case Instruction::Xor: + if (Instruction *I = foldICmpXorConstant(Cmp, BO, C)) + return I; + break; + case Instruction::And: + if (Instruction *I = foldICmpAndConstant(Cmp, BO, C)) + return I; + break; + case Instruction::Or: + if (Instruction *I = foldICmpOrConstant(Cmp, BO, C)) + return I; + break; + case Instruction::Mul: + if (Instruction *I = foldICmpMulConstant(Cmp, BO, C)) + return I; + break; + case Instruction::Shl: + if (Instruction *I = foldICmpShlConstant(Cmp, BO, C)) + return I; + break; + case Instruction::LShr: + case Instruction::AShr: + if (Instruction *I = foldICmpShrConstant(Cmp, BO, C)) + return I; + break; + case Instruction::SRem: + if (Instruction *I = foldICmpSRemConstant(Cmp, BO, C)) + return I; + break; + case Instruction::UDiv: + if (Instruction *I = foldICmpUDivConstant(Cmp, BO, C)) + return I; + LLVM_FALLTHROUGH; + case Instruction::SDiv: + if (Instruction *I = foldICmpDivConstant(Cmp, BO, C)) + return I; + break; + case Instruction::Sub: + if (Instruction *I = foldICmpSubConstant(Cmp, BO, C)) + return I; + break; + case Instruction::Add: + if (Instruction *I = foldICmpAddConstant(Cmp, BO, C)) + return I; + break; + default: + break; + } + + // TODO: These folds could be refactored to be part of the above calls. + return foldICmpBinOpEqualityWithConstant(Cmp, BO, C); +} + +/// Fold an icmp with LLVM intrinsic and constant operand: icmp Pred II, C. 
+Instruction *InstCombinerImpl::foldICmpIntrinsicWithConstant(ICmpInst &Cmp, + IntrinsicInst *II, + const APInt &C) { + if (Cmp.isEquality()) + return foldICmpEqIntrinsicWithConstant(Cmp, II, C); + + Type *Ty = II->getType(); + unsigned BitWidth = C.getBitWidth(); + ICmpInst::Predicate Pred = Cmp.getPredicate(); + switch (II->getIntrinsicID()) { + case Intrinsic::ctpop: { + // (ctpop X > BitWidth - 1) --> X == -1 + Value *X = II->getArgOperand(0); + if (C == BitWidth - 1 && Pred == ICmpInst::ICMP_UGT) + return CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_EQ, X, + ConstantInt::getAllOnesValue(Ty)); + // (ctpop X < BitWidth) --> X != -1 + if (C == BitWidth && Pred == ICmpInst::ICMP_ULT) + return CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_NE, X, + ConstantInt::getAllOnesValue(Ty)); + break; + } + case Intrinsic::ctlz: { + // ctlz(0bXXXXXXXX) > 3 -> 0bXXXXXXXX < 0b00010000 + if (Pred == ICmpInst::ICMP_UGT && C.ult(BitWidth)) { + unsigned Num = C.getLimitedValue(); + APInt Limit = APInt::getOneBitSet(BitWidth, BitWidth - Num - 1); + return CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_ULT, + II->getArgOperand(0), ConstantInt::get(Ty, Limit)); + } + + // ctlz(0bXXXXXXXX) < 3 -> 0bXXXXXXXX > 0b00011111 + if (Pred == ICmpInst::ICMP_ULT && C.uge(1) && C.ule(BitWidth)) { + unsigned Num = C.getLimitedValue(); + APInt Limit = APInt::getLowBitsSet(BitWidth, BitWidth - Num); + return CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_UGT, + II->getArgOperand(0), ConstantInt::get(Ty, Limit)); + } + break; + } + case Intrinsic::cttz: { + // Limit to one use to ensure we don't increase instruction count. 
+ if (!II->hasOneUse()) + return nullptr; + + // cttz(0bXXXXXXXX) > 3 -> 0bXXXXXXXX & 0b00001111 == 0 + if (Pred == ICmpInst::ICMP_UGT && C.ult(BitWidth)) { + APInt Mask = APInt::getLowBitsSet(BitWidth, C.getLimitedValue() + 1); + return CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_EQ, + Builder.CreateAnd(II->getArgOperand(0), Mask), + ConstantInt::getNullValue(Ty)); + } + + // cttz(0bXXXXXXXX) < 3 -> 0bXXXXXXXX & 0b00000111 != 0 + if (Pred == ICmpInst::ICMP_ULT && C.uge(1) && C.ule(BitWidth)) { + APInt Mask = APInt::getLowBitsSet(BitWidth, C.getLimitedValue()); + return CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_NE, + Builder.CreateAnd(II->getArgOperand(0), Mask), + ConstantInt::getNullValue(Ty)); + } + break; + } + default: + break; + } + + return nullptr; +} + +/// Handle icmp with constant (but not simple integer constant) RHS. +Instruction *InstCombinerImpl::foldICmpInstWithConstantNotInt(ICmpInst &I) { + Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); + Constant *RHSC = dyn_cast<Constant>(Op1); + Instruction *LHSI = dyn_cast<Instruction>(Op0); + if (!RHSC || !LHSI) + return nullptr; + + switch (LHSI->getOpcode()) { + case Instruction::GetElementPtr: + // icmp pred GEP (P, int 0, int 0, int 0), null -> icmp pred P, null + if (RHSC->isNullValue() && + cast<GetElementPtrInst>(LHSI)->hasAllZeroIndices()) + return new ICmpInst( + I.getPredicate(), LHSI->getOperand(0), + Constant::getNullValue(LHSI->getOperand(0)->getType())); + break; + case Instruction::PHI: + // Only fold icmp into the PHI if the phi and icmp are in the same + // block. If in the same block, we're encouraging jump threading. If + // not, we are just pessimizing the code by making an i1 phi. 
+ if (LHSI->getParent() == I.getParent()) + if (Instruction *NV = foldOpIntoPhi(I, cast<PHINode>(LHSI))) + return NV; + break; + case Instruction::IntToPtr: + // icmp pred inttoptr(X), null -> icmp pred X, 0 + if (RHSC->isNullValue() && + DL.getIntPtrType(RHSC->getType()) == LHSI->getOperand(0)->getType()) + return new ICmpInst( + I.getPredicate(), LHSI->getOperand(0), + Constant::getNullValue(LHSI->getOperand(0)->getType())); + break; + + case Instruction::Load: + // Try to optimize things like "A[i] > 4" to index computations. + if (GetElementPtrInst *GEP = + dyn_cast<GetElementPtrInst>(LHSI->getOperand(0))) + if (GlobalVariable *GV = dyn_cast<GlobalVariable>(GEP->getOperand(0))) + if (Instruction *Res = + foldCmpLoadFromIndexedGlobal(cast<LoadInst>(LHSI), GEP, GV, I)) + return Res; + break; + } + + return nullptr; +} + +Instruction *InstCombinerImpl::foldSelectICmp(ICmpInst::Predicate Pred, + SelectInst *SI, Value *RHS, + const ICmpInst &I) { + // Try to fold the comparison into the select arms, which will cause the + // select to be converted into a logical and/or. + auto SimplifyOp = [&](Value *Op, bool SelectCondIsTrue) -> Value * { + if (Value *Res = simplifyICmpInst(Pred, Op, RHS, SQ)) + return Res; + if (Optional<bool> Impl = isImpliedCondition(SI->getCondition(), Pred, Op, + RHS, DL, SelectCondIsTrue)) + return ConstantInt::get(I.getType(), *Impl); + return nullptr; + }; + + ConstantInt *CI = nullptr; + Value *Op1 = SimplifyOp(SI->getOperand(1), true); + if (Op1) + CI = dyn_cast<ConstantInt>(Op1); + + Value *Op2 = SimplifyOp(SI->getOperand(2), false); + if (Op2) + CI = dyn_cast<ConstantInt>(Op2); + + // We only want to perform this transformation if it will not lead to + // additional code. 
This is true if either both sides of the select + // fold to a constant (in which case the icmp is replaced with a select + // which will usually simplify) or this is the only user of the + // select (in which case we are trading a select+icmp for a simpler + // select+icmp) or all uses of the select can be replaced based on + // dominance information ("Global cases"). + bool Transform = false; + if (Op1 && Op2) + Transform = true; + else if (Op1 || Op2) { + // Local case + if (SI->hasOneUse()) + Transform = true; + // Global cases + else if (CI && !CI->isZero()) + // When Op1 is constant try replacing select with second operand. + // Otherwise Op2 is constant and try replacing select with first + // operand. + Transform = replacedSelectWithOperand(SI, &I, Op1 ? 2 : 1); + } + if (Transform) { + if (!Op1) + Op1 = Builder.CreateICmp(Pred, SI->getOperand(1), RHS, I.getName()); + if (!Op2) + Op2 = Builder.CreateICmp(Pred, SI->getOperand(2), RHS, I.getName()); + return SelectInst::Create(SI->getOperand(0), Op1, Op2); + } + + return nullptr; +} + +/// Some comparisons can be simplified. +/// In this case, we are looking for comparisons that look like +/// a check for a lossy truncation. +/// Folds: +/// icmp SrcPred (x & Mask), x to icmp DstPred x, Mask +/// Where Mask is some pattern that produces all-ones in low bits: +/// (-1 >> y) +/// ((-1 << y) >> y) <- non-canonical, has extra uses +/// ~(-1 << y) +/// ((1 << y) + (-1)) <- non-canonical, has extra uses +/// The Mask can be a constant, too. +/// For some predicates, the operands are commutative. +/// For others, x can only be on a specific side. 
+static Value *foldICmpWithLowBitMaskedVal(ICmpInst &I, + InstCombiner::BuilderTy &Builder) { + ICmpInst::Predicate SrcPred; + Value *X, *M, *Y; + auto m_VariableMask = m_CombineOr( + m_CombineOr(m_Not(m_Shl(m_AllOnes(), m_Value())), + m_Add(m_Shl(m_One(), m_Value()), m_AllOnes())), + m_CombineOr(m_LShr(m_AllOnes(), m_Value()), + m_LShr(m_Shl(m_AllOnes(), m_Value(Y)), m_Deferred(Y)))); + auto m_Mask = m_CombineOr(m_VariableMask, m_LowBitMask()); + if (!match(&I, m_c_ICmp(SrcPred, + m_c_And(m_CombineAnd(m_Mask, m_Value(M)), m_Value(X)), + m_Deferred(X)))) + return nullptr; + + ICmpInst::Predicate DstPred; + switch (SrcPred) { + case ICmpInst::Predicate::ICMP_EQ: + // x & (-1 >> y) == x -> x u<= (-1 >> y) + DstPred = ICmpInst::Predicate::ICMP_ULE; + break; + case ICmpInst::Predicate::ICMP_NE: + // x & (-1 >> y) != x -> x u> (-1 >> y) + DstPred = ICmpInst::Predicate::ICMP_UGT; + break; + case ICmpInst::Predicate::ICMP_ULT: + // x & (-1 >> y) u< x -> x u> (-1 >> y) + // x u> x & (-1 >> y) -> x u> (-1 >> y) + DstPred = ICmpInst::Predicate::ICMP_UGT; + break; + case ICmpInst::Predicate::ICMP_UGE: + // x & (-1 >> y) u>= x -> x u<= (-1 >> y) + // x u<= x & (-1 >> y) -> x u<= (-1 >> y) + DstPred = ICmpInst::Predicate::ICMP_ULE; + break; + case ICmpInst::Predicate::ICMP_SLT: + // x & (-1 >> y) s< x -> x s> (-1 >> y) + // x s> x & (-1 >> y) -> x s> (-1 >> y) + if (!match(M, m_Constant())) // Can not do this fold with non-constant. + return nullptr; + if (!match(M, m_NonNegative())) // Must not have any -1 vector elements. + return nullptr; + DstPred = ICmpInst::Predicate::ICMP_SGT; + break; + case ICmpInst::Predicate::ICMP_SGE: + // x & (-1 >> y) s>= x -> x s<= (-1 >> y) + // x s<= x & (-1 >> y) -> x s<= (-1 >> y) + if (!match(M, m_Constant())) // Can not do this fold with non-constant. + return nullptr; + if (!match(M, m_NonNegative())) // Must not have any -1 vector elements. 
+ return nullptr; + DstPred = ICmpInst::Predicate::ICMP_SLE; + break; + case ICmpInst::Predicate::ICMP_SGT: + case ICmpInst::Predicate::ICMP_SLE: + return nullptr; + case ICmpInst::Predicate::ICMP_UGT: + case ICmpInst::Predicate::ICMP_ULE: + llvm_unreachable("Instsimplify took care of commut. variant"); + break; + default: + llvm_unreachable("All possible folds are handled."); + } + + // The mask value may be a vector constant that has undefined elements. But it + // may not be safe to propagate those undefs into the new compare, so replace + // those elements by copying an existing, defined, and safe scalar constant. + Type *OpTy = M->getType(); + auto *VecC = dyn_cast<Constant>(M); + auto *OpVTy = dyn_cast<FixedVectorType>(OpTy); + if (OpVTy && VecC && VecC->containsUndefOrPoisonElement()) { + Constant *SafeReplacementConstant = nullptr; + for (unsigned i = 0, e = OpVTy->getNumElements(); i != e; ++i) { + if (!isa<UndefValue>(VecC->getAggregateElement(i))) { + SafeReplacementConstant = VecC->getAggregateElement(i); + break; + } + } + assert(SafeReplacementConstant && "Failed to find undef replacement"); + M = Constant::replaceUndefsWith(VecC, SafeReplacementConstant); + } + + return Builder.CreateICmp(DstPred, X, M); +} + +/// Some comparisons can be simplified. +/// In this case, we are looking for comparisons that look like +/// a check for a lossy signed truncation. +/// Folds: (MaskedBits is a constant.) +/// ((%x << MaskedBits) a>> MaskedBits) SrcPred %x +/// Into: +/// (add %x, (1 << (KeptBits-1))) DstPred (1 << KeptBits) +/// Where KeptBits = bitwidth(%x) - MaskedBits +static Value * +foldICmpWithTruncSignExtendedVal(ICmpInst &I, + InstCombiner::BuilderTy &Builder) { + ICmpInst::Predicate SrcPred; + Value *X; + const APInt *C0, *C1; // FIXME: non-splats, potentially with undef. + // We are ok with 'shl' having multiple uses, but 'ashr' must be one-use. 
+ if (!match(&I, m_c_ICmp(SrcPred, + m_OneUse(m_AShr(m_Shl(m_Value(X), m_APInt(C0)), + m_APInt(C1))), + m_Deferred(X)))) + return nullptr; + + // Potential handling of non-splats: for each element: + // * if both are undef, replace with constant 0. + // Because (1<<0) is OK and is 1, and ((1<<0)>>1) is also OK and is 0. + // * if both are not undef, and are different, bailout. + // * else, only one is undef, then pick the non-undef one. + + // The shift amount must be equal. + if (*C0 != *C1) + return nullptr; + const APInt &MaskedBits = *C0; + assert(MaskedBits != 0 && "shift by zero should be folded away already."); + + ICmpInst::Predicate DstPred; + switch (SrcPred) { + case ICmpInst::Predicate::ICMP_EQ: + // ((%x << MaskedBits) a>> MaskedBits) == %x + // => + // (add %x, (1 << (KeptBits-1))) u< (1 << KeptBits) + DstPred = ICmpInst::Predicate::ICMP_ULT; + break; + case ICmpInst::Predicate::ICMP_NE: + // ((%x << MaskedBits) a>> MaskedBits) != %x + // => + // (add %x, (1 << (KeptBits-1))) u>= (1 << KeptBits) + DstPred = ICmpInst::Predicate::ICMP_UGE; + break; + // FIXME: are more folds possible? 
  default:
    return nullptr;
  }

  auto *XType = X->getType();
  const unsigned XBitWidth = XType->getScalarSizeInBits();
  const APInt BitWidth = APInt(XBitWidth, XBitWidth);
  assert(BitWidth.ugt(MaskedBits) && "shifts should leave some bits untouched");

  // KeptBits = bitwidth(%x) - MaskedBits
  const APInt KeptBits = BitWidth - MaskedBits;
  assert(KeptBits.ugt(0) && KeptBits.ult(BitWidth) && "unreachable");
  // ICmpCst = (1 << KeptBits)
  const APInt ICmpCst = APInt(XBitWidth, 1).shl(KeptBits);
  assert(ICmpCst.isPowerOf2());
  // AddCst = (1 << (KeptBits-1))
  const APInt AddCst = ICmpCst.lshr(1);
  assert(AddCst.ult(ICmpCst) && AddCst.isPowerOf2());

  // T0 = add %x, AddCst
  Value *T0 = Builder.CreateAdd(X, ConstantInt::get(XType, AddCst));
  // T1 = T0 DstPred ICmpCst
  Value *T1 = Builder.CreateICmp(DstPred, T0, ConstantInt::get(XType, ICmpCst));

  return T1;
}

// Given pattern:
//   icmp eq/ne (and ((x shift Q), (y oppositeshift K))), 0
// we should move shifts to the same hand of 'and', i.e. rewrite as
//   icmp eq/ne (and (x shift (Q+K)), y), 0  iff (Q+K) u< bitwidth(x)
// We are only interested in opposite logical shifts here.
// One of the shifts can be truncated.
// If we can, we want to end up creating 'lshr' shift.
//
// \param I       The comparison to inspect; it must be an equality compare
//                against zero whose LHS has a single use, otherwise nullptr is
//                returned.
// \param SQ      Query used to simplify the sum of the two shift amounts and
//                to compute known bits.
// \param Builder Used to emit the replacement zext/shift/and/icmp.
// \returns the replacement compare value, or nullptr if the fold does not
//          apply.
static Value *
foldShiftIntoShiftInAnotherHandOfAndInICmp(ICmpInst &I, const SimplifyQuery SQ,
                                           InstCombiner::BuilderTy &Builder) {
  if (!I.isEquality() || !match(I.getOperand(1), m_Zero()) ||
      !I.getOperand(0)->hasOneUse())
    return nullptr;

  auto m_AnyLogicalShift = m_LogicalShift(m_Value(), m_Value());

  // Look for an 'and' of two logical shifts, one of which may be truncated.
  // We use m_TruncOrSelf() on the RHS to correctly handle commutative case.
  Instruction *XShift, *MaybeTruncation, *YShift;
  if (!match(
          I.getOperand(0),
          m_c_And(m_CombineAnd(m_AnyLogicalShift, m_Instruction(XShift)),
                  m_CombineAnd(m_TruncOrSelf(m_CombineAnd(
                                   m_AnyLogicalShift, m_Instruction(YShift))),
                               m_Instruction(MaybeTruncation)))))
    return nullptr;

  // We potentially looked past 'trunc', but only when matching YShift,
  // therefore YShift must have the widest type.
  Instruction *WidestShift = YShift;
  // Therefore XShift must have the shallowest type.
  // Or they both have identical types if there was no truncation.
  Instruction *NarrowestShift = XShift;

  Type *WidestTy = WidestShift->getType();
  Type *NarrowestTy = NarrowestShift->getType();
  assert(NarrowestTy == I.getOperand(0)->getType() &&
         "We did not look past any shifts while matching XShift though.");
  bool HadTrunc = WidestTy != I.getOperand(0)->getType();

  // If YShift is a 'lshr', swap the shifts around.
  // Note that WidestShift/NarrowestShift were captured before this swap and
  // keep referring to the originally matched instructions.
  if (match(YShift, m_LShr(m_Value(), m_Value())))
    std::swap(XShift, YShift);

  // The shifts must be in opposite directions.
  auto XShiftOpcode = XShift->getOpcode();
  if (XShiftOpcode == YShift->getOpcode())
    return nullptr; // Do not care about same-direction shifts here.

  // Both instructions were already matched as logical shifts above, so these
  // matches are only used to bind the operands (looking through a zext of the
  // shift amount); their results are intentionally not checked.
  Value *X, *XShAmt, *Y, *YShAmt;
  match(XShift, m_BinOp(m_Value(X), m_ZExtOrSelf(m_Value(XShAmt))));
  match(YShift, m_BinOp(m_Value(Y), m_ZExtOrSelf(m_Value(YShAmt))));

  // If one of the values being shifted is a constant, then we will end with
  // and+icmp, and [zext+]shift instrs will be constant-folded. If they are not,
  // however, we will need to ensure that we won't increase instruction count.
  if (!isa<Constant>(X) && !isa<Constant>(Y)) {
    // At least one of the hands of the 'and' should be one-use shift.
    if (!match(I.getOperand(0),
               m_c_And(m_OneUse(m_AnyLogicalShift), m_Value())))
      return nullptr;
    if (HadTrunc) {
      // Due to the 'trunc', we will need to widen X. For that either the old
      // 'trunc' or the shift amt in the non-truncated shift should be one-use.
      if (!MaybeTruncation->hasOneUse() &&
          !NarrowestShift->getOperand(1)->hasOneUse())
        return nullptr;
    }
  }

  // We have two shift amounts from two different shifts. The types of those
  // shift amounts may not match. If that's the case let's bailout now.
  if (XShAmt->getType() != YShAmt->getType())
    return nullptr;

  // As input, we have the following pattern:
  //   icmp eq/ne (and ((x shift Q), (y oppositeshift K))), 0
  // We want to rewrite that as:
  //   icmp eq/ne (and (x shift (Q+K)), y), 0  iff (Q+K) u< bitwidth(x)
  // While we know that originally (Q+K) would not overflow
  // (because  2 * (N-1) u<= iN -1), we have looked past extensions of
  // shift amounts. so it may now overflow in smaller bitwidth.
  // To ensure that does not happen, we need to ensure that the total maximal
  // shift amount is still representable in that smaller bit width.
  unsigned MaximalPossibleTotalShiftAmount =
      (WidestTy->getScalarSizeInBits() - 1) +
      (NarrowestTy->getScalarSizeInBits() - 1);
  APInt MaximalRepresentableShiftAmount =
      APInt::getAllOnes(XShAmt->getType()->getScalarSizeInBits());
  if (MaximalRepresentableShiftAmount.ult(MaximalPossibleTotalShiftAmount))
    return nullptr;

  // Can we fold (XShAmt+YShAmt) ?
  auto *NewShAmt = dyn_cast_or_null<Constant>(
      simplifyAddInst(XShAmt, YShAmt, /*isNSW=*/false,
                      /*isNUW=*/false, SQ.getWithInstruction(&I)));
  if (!NewShAmt)
    return nullptr;
  NewShAmt = ConstantExpr::getZExtOrBitCast(NewShAmt, WidestTy);
  unsigned WidestBitWidth = WidestTy->getScalarSizeInBits();

  // Is the new shift amount smaller than the bit width?
  // FIXME: could also rely on ConstantRange.
  if (!match(NewShAmt,
             m_SpecificInt_ICMP(ICmpInst::Predicate::ICMP_ULT,
                                APInt(WidestBitWidth, WidestBitWidth))))
    return nullptr;

  // An extra legality check is needed if we had trunc-of-lshr.
  if (HadTrunc && match(WidestShift, m_LShr(m_Value(), m_Value()))) {
    // Returns true as soon as any one of the sufficient preconditions below is
    // proven; returns false when none of them could be established.
    auto CanFold = [NewShAmt, WidestBitWidth, NarrowestShift, SQ,
                    WidestShift]() {
      // It isn't obvious whether it's worth it to analyze non-constants here.
      // Also, let's basically give up on non-splat cases, pessimizing vectors.
      // If *any* of these preconditions matches we can perform the fold.
      Constant *NewShAmtSplat = NewShAmt->getType()->isVectorTy()
                                    ? NewShAmt->getSplatValue()
                                    : NewShAmt;
      // If it's edge-case shift (by 0 or by WidestBitWidth-1) we can fold.
      if (NewShAmtSplat &&
          (NewShAmtSplat->isNullValue() ||
           NewShAmtSplat->getUniqueInteger() == WidestBitWidth - 1))
        return true;
      // We consider *min* leading zeros so a single outlier
      // blocks the transform as opposed to allowing it.
      if (auto *C = dyn_cast<Constant>(NarrowestShift->getOperand(0))) {
        KnownBits Known = computeKnownBits(C, SQ.DL);
        unsigned MinLeadZero = Known.countMinLeadingZeros();
        // If the value being shifted has at most lowest bit set we can fold.
        unsigned MaxActiveBits = Known.getBitWidth() - MinLeadZero;
        if (MaxActiveBits <= 1)
          return true;
        // Precondition:  NewShAmt u<= countLeadingZeros(C)
        if (NewShAmtSplat && NewShAmtSplat->getUniqueInteger().ule(MinLeadZero))
          return true;
      }
      if (auto *C = dyn_cast<Constant>(WidestShift->getOperand(0))) {
        KnownBits Known = computeKnownBits(C, SQ.DL);
        unsigned MinLeadZero = Known.countMinLeadingZeros();
        // If the value being shifted has at most lowest bit set we can fold.
        unsigned MaxActiveBits = Known.getBitWidth() - MinLeadZero;
        if (MaxActiveBits <= 1)
          return true;
        // Precondition:  ((WidestBitWidth-1)-NewShAmt) u<= countLeadingZeros(C)
        if (NewShAmtSplat) {
          APInt AdjNewShAmt =
              (WidestBitWidth - 1) - NewShAmtSplat->getUniqueInteger();
          if (AdjNewShAmt.ule(MinLeadZero))
            return true;
        }
      }
      return false; // Can't tell if it's ok.
    };
    if (!CanFold())
      return nullptr;
  }

  // All good, we can do this fold.
  // Both shifted values are brought to the widest type before combining.
  X = Builder.CreateZExt(X, WidestTy);
  Y = Builder.CreateZExt(Y, WidestTy);
  // The shift is the same that was for X.
  Value *T0 = XShiftOpcode == Instruction::BinaryOps::LShr
                  ? Builder.CreateLShr(X, NewShAmt)
                  : Builder.CreateShl(X, NewShAmt);
  Value *T1 = Builder.CreateAnd(T0, Y);
  return Builder.CreateICmp(I.getPredicate(), T1,
                            Constant::getNullValue(WidestTy));
}

/// Fold
///   (-1 u/ x) u< y
///   ((x * y) ?/ x) != y
/// to
///   @llvm.?mul.with.overflow(x, y) plus extraction of overflow bit
/// Note that the comparison is commutative, while inverted (u>=, ==) predicate
/// will mean that we are looking for the opposite answer.
///
/// \returns the value that replaces the compare (the overflow bit, negated
///          when the inverted predicate was matched), or nullptr when neither
///          pattern matches.
Value *InstCombinerImpl::foldMultiplicationOverflowCheck(ICmpInst &I) {
  ICmpInst::Predicate Pred;
  Value *X, *Y;
  // Mul is set to nullptr for the first pattern (which has no multiplication)
  // and bound to the matched 'mul' for the second one.
  Instruction *Mul;
  Instruction *Div;
  bool NeedNegation;
  // Look for: (-1 u/ x) u</u>= y
  if (!I.isEquality() &&
      match(&I, m_c_ICmp(Pred,
                         m_CombineAnd(m_OneUse(m_UDiv(m_AllOnes(), m_Value(X))),
                                      m_Instruction(Div)),
                         m_Value(Y)))) {
    Mul = nullptr;

    // Are we checking that overflow does not happen, or does happen?
    switch (Pred) {
    case ICmpInst::Predicate::ICMP_ULT:
      NeedNegation = false;
      break; // OK
    case ICmpInst::Predicate::ICMP_UGE:
      NeedNegation = true;
      break; // OK
    default:
      return nullptr; // Wrong predicate.
    }
  } else // Look for: ((x * y) / x) !=/== y
      if (I.isEquality() &&
          match(&I,
                m_c_ICmp(Pred, m_Value(Y),
                         m_CombineAnd(
                             m_OneUse(m_IDiv(m_CombineAnd(m_c_Mul(m_Deferred(Y),
                                                                  m_Value(X)),
                                                          m_Instruction(Mul)),
                                             m_Deferred(X))),
                             m_Instruction(Div))))) {
    NeedNegation = Pred == ICmpInst::Predicate::ICMP_EQ;
  } else
    return nullptr;

  // The guard restores the builder's insertion point when this function
  // returns, since it may be moved to the multiplication below.
  BuilderTy::InsertPointGuard Guard(Builder);
  // If the pattern included (x * y), we'll want to insert new instructions
  // right before that original multiplication so that we can replace it.
  bool MulHadOtherUses = Mul && !Mul->hasOneUse();
  if (MulHadOtherUses)
    Builder.SetInsertPoint(Mul);

  // The signedness of the intrinsic follows the matched division opcode.
  Function *F = Intrinsic::getDeclaration(I.getModule(),
                                          Div->getOpcode() == Instruction::UDiv
                                              ? Intrinsic::umul_with_overflow
                                              : Intrinsic::smul_with_overflow,
                                          X->getType());
  CallInst *Call = Builder.CreateCall(F, {X, Y}, "mul");

  // If the multiplication was used elsewhere, to ensure that we don't leave
  // "duplicate" instructions, replace uses of that original multiplication
  // with the multiplication result from the with.overflow intrinsic.
  if (MulHadOtherUses)
    replaceInstUsesWith(*Mul, Builder.CreateExtractValue(Call, 0, "mul.val"));

  Value *Res = Builder.CreateExtractValue(Call, 1, "mul.ov");
  if (NeedNegation) // This technically increases instruction count.
    Res = Builder.CreateNot(Res, "mul.not.ov");

  // If we replaced the mul, erase it. Do this after all uses of Builder,
  // as the mul is used as insertion point.
  if (MulHadOtherUses)
    eraseInstFromFunction(*Mul);

  return Res;
}

/// Fold a comparison between an nsw-negated value and the value itself (in
/// either operand order) into a comparison of that value against zero.
///
/// The new compare has X on the left and zero on the right. Signed predicates
/// are swapped, unsigned predicates are replaced by their signed counterpart,
/// and equality predicates are kept as they are.
///
/// \returns the new compare instruction, or nullptr if \p I does not match.
static Instruction *foldICmpXNegX(ICmpInst &I) {
  CmpInst::Predicate Pred;
  Value *X;
  if (!match(&I, m_c_ICmp(Pred, m_NSWNeg(m_Value(X)), m_Deferred(X))))
    return nullptr;

  if (ICmpInst::isSigned(Pred))
    Pred = ICmpInst::getSwappedPredicate(Pred);
  else if (ICmpInst::isUnsigned(Pred))
    Pred = ICmpInst::getSignedPredicate(Pred);
  // else for equality-comparisons just keep the predicate.

  return ICmpInst::Create(Instruction::ICmp, Pred, X,
                          Constant::getNullValue(X->getType()), I.getName());
}

/// Try to fold icmp (binop), X or icmp X, (binop).
/// TODO: A large part of this logic is duplicated in InstSimplify's
/// simplifyICmpWithBinOp(). We should be able to share that and avoid the code
/// duplication.
Instruction *InstCombinerImpl::foldICmpBinOp(ICmpInst &I,
                                             const SimplifyQuery &SQ) {
  // Q carries I as the context instruction; it is used for the known-non-zero
  // queries further down.
  const SimplifyQuery Q = SQ.getWithInstruction(&I);
  Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);

  // Special logic for binary operators.
  BinaryOperator *BO0 = dyn_cast<BinaryOperator>(Op0);
  BinaryOperator *BO1 = dyn_cast<BinaryOperator>(Op1);
  if (!BO0 && !BO1)
    return nullptr;

  if (Instruction *NewICmp = foldICmpXNegX(I))
    return NewICmp;

  const CmpInst::Predicate Pred = I.getPredicate();
  Value *X;

  // Convert add-with-unsigned-overflow comparisons into a 'not' with compare.
  // (Op1 + X) u</u>= Op1 --> ~Op1 u</u>= X
  if (match(Op0, m_OneUse(m_c_Add(m_Specific(Op1), m_Value(X)))) &&
      (Pred == ICmpInst::ICMP_ULT || Pred == ICmpInst::ICMP_UGE))
    return new ICmpInst(Pred, Builder.CreateNot(Op1), X);
  // Op0 u>/u<= (Op0 + X) --> X u>/u<= ~Op0
  if (match(Op1, m_OneUse(m_c_Add(m_Specific(Op0), m_Value(X)))) &&
      (Pred == ICmpInst::ICMP_UGT || Pred == ICmpInst::ICMP_ULE))
    return new ICmpInst(Pred, X, Builder.CreateNot(Op0));

  {
    // (Op1 + X) + C u</u>= Op1 --> ~C - X u</u>= Op1
    Constant *C;
    if (match(Op0, m_OneUse(m_Add(m_c_Add(m_Specific(Op1), m_Value(X)),
                                  m_ImmConstant(C)))) &&
        (Pred == ICmpInst::ICMP_ULT || Pred == ICmpInst::ICMP_UGE)) {
      Constant *C2 = ConstantExpr::getNot(C);
      return new ICmpInst(Pred, Builder.CreateSub(C2, X), Op1);
    }
    // Op0 u>/u<= (Op0 + X) + C --> Op0 u>/u<= ~C - X
    if (match(Op1, m_OneUse(m_Add(m_c_Add(m_Specific(Op0), m_Value(X)),
                                  m_ImmConstant(C)))) &&
        (Pred == ICmpInst::ICMP_UGT || Pred == ICmpInst::ICMP_ULE)) {
      Constant *C2 = ConstantExpr::getNot(C);
      return new ICmpInst(Pred, Op0, Builder.CreateSub(C2, X));
    }
  }

  {
    // Similar to above: an unsigned overflow comparison may use offset + mask:
    // ((Op1 + C) & C) u<  Op1 --> Op1 != 0
    // ((Op1 + C) & C) u>= Op1 --> Op1 == 0
    // Op0 u>  ((Op0 + C) & C) --> Op0 != 0
    // Op0 u<= ((Op0 + C) & C) --> Op0 == 0
    BinaryOperator *BO;
    const APInt *C;
    if ((Pred == ICmpInst::ICMP_ULT || Pred == ICmpInst::ICMP_UGE) &&
        match(Op0, m_And(m_BinOp(BO), m_LowBitMask(C))) &&
        match(BO, m_Add(m_Specific(Op1), m_SpecificIntAllowUndef(*C)))) {
      CmpInst::Predicate NewPred =
          Pred == ICmpInst::ICMP_ULT ? ICmpInst::ICMP_NE : ICmpInst::ICMP_EQ;
      Constant *Zero = ConstantInt::getNullValue(Op1->getType());
      return new ICmpInst(NewPred, Op1, Zero);
    }

    if ((Pred == ICmpInst::ICMP_UGT || Pred == ICmpInst::ICMP_ULE) &&
        match(Op1, m_And(m_BinOp(BO), m_LowBitMask(C))) &&
        match(BO, m_Add(m_Specific(Op0), m_SpecificIntAllowUndef(*C)))) {
      CmpInst::Predicate NewPred =
          Pred == ICmpInst::ICMP_UGT ? ICmpInst::ICMP_NE : ICmpInst::ICMP_EQ;
      // Both icmp operands share one type, so Op1's type is also Op0's type.
      Constant *Zero = ConstantInt::getNullValue(Op1->getType());
      return new ICmpInst(NewPred, Op0, Zero);
    }
  }

  // NoOpNWrapProblem is true when operand N is an overflowing binary operator
  // whose wrap cannot affect this predicate: any equality predicate, or an
  // unsigned/signed predicate paired with the matching nuw/nsw flag.
  bool NoOp0WrapProblem = false, NoOp1WrapProblem = false;
  if (BO0 && isa<OverflowingBinaryOperator>(BO0))
    NoOp0WrapProblem =
        ICmpInst::isEquality(Pred) ||
        (CmpInst::isUnsigned(Pred) && BO0->hasNoUnsignedWrap()) ||
        (CmpInst::isSigned(Pred) && BO0->hasNoSignedWrap());
  if (BO1 && isa<OverflowingBinaryOperator>(BO1))
    NoOp1WrapProblem =
        ICmpInst::isEquality(Pred) ||
        (CmpInst::isUnsigned(Pred) && BO1->hasNoUnsignedWrap()) ||
        (CmpInst::isSigned(Pred) && BO1->hasNoSignedWrap());

  // Analyze the case when either Op0 or Op1 is an add instruction.
  // Op0 = A + B (or A and B are null); Op1 = C + D (or C and D are null).
  Value *A = nullptr, *B = nullptr, *C = nullptr, *D = nullptr;
  if (BO0 && BO0->getOpcode() == Instruction::Add) {
    A = BO0->getOperand(0);
    B = BO0->getOperand(1);
  }
  if (BO1 && BO1->getOpcode() == Instruction::Add) {
    C = BO1->getOperand(0);
    D = BO1->getOperand(1);
  }

  // icmp (A+B), A -> icmp B, 0 for equalities or if there is no overflow.
  // icmp (A+B), B -> icmp A, 0 for equalities or if there is no overflow.
  if ((A == Op1 || B == Op1) && NoOp0WrapProblem)
    return new ICmpInst(Pred, A == Op1 ? B : A,
                        Constant::getNullValue(Op1->getType()));

  // icmp C, (C+D) -> icmp 0, D for equalities or if there is no overflow.
  // icmp D, (C+D) -> icmp 0, C for equalities or if there is no overflow.
  if ((C == Op0 || D == Op0) && NoOp1WrapProblem)
    return new ICmpInst(Pred, Constant::getNullValue(Op0->getType()),
                        C == Op0 ? D : C);

  // icmp (A+B), (A+D) -> icmp B, D for equalities or if there is no overflow.
  if (A && C && (A == C || A == D || B == C || B == D) && NoOp0WrapProblem &&
      NoOp1WrapProblem) {
    // Determine Y and Z in the form icmp (X+Y), (X+Z).
    Value *Y, *Z;
    if (A == C) {
      // C + B == C + D  ->  B == D
      Y = B;
      Z = D;
    } else if (A == D) {
      // D + B == C + D  ->  B == C
      Y = B;
      Z = C;
    } else if (B == C) {
      // A + C == C + D  ->  A == D
      Y = A;
      Z = D;
    } else {
      assert(B == D);
      // A + D == C + D  ->  A == C
      Y = A;
      Z = C;
    }
    return new ICmpInst(Pred, Y, Z);
  }

  // icmp slt (A + -1), Op1 -> icmp sle A, Op1
  if (A && NoOp0WrapProblem && Pred == CmpInst::ICMP_SLT &&
      match(B, m_AllOnes()))
    return new ICmpInst(CmpInst::ICMP_SLE, A, Op1);

  // icmp sge (A + -1), Op1 -> icmp sgt A, Op1
  if (A && NoOp0WrapProblem && Pred == CmpInst::ICMP_SGE &&
      match(B, m_AllOnes()))
    return new ICmpInst(CmpInst::ICMP_SGT, A, Op1);

  // icmp sle (A + 1), Op1 -> icmp slt A, Op1
  if (A && NoOp0WrapProblem && Pred == CmpInst::ICMP_SLE && match(B, m_One()))
    return new ICmpInst(CmpInst::ICMP_SLT, A, Op1);

  // icmp sgt (A + 1), Op1 -> icmp sge A, Op1
  if (A && NoOp0WrapProblem && Pred == CmpInst::ICMP_SGT && match(B, m_One()))
    return new ICmpInst(CmpInst::ICMP_SGE, A, Op1);

  // icmp sgt Op0, (C + -1) -> icmp sge Op0, C
  if (C && NoOp1WrapProblem && Pred == CmpInst::ICMP_SGT &&
      match(D, m_AllOnes()))
    return new ICmpInst(CmpInst::ICMP_SGE, Op0, C);

  // icmp sle Op0, (C + -1) -> icmp slt Op0, C
  if (C && NoOp1WrapProblem && Pred == CmpInst::ICMP_SLE &&
      match(D, m_AllOnes()))
    return new ICmpInst(CmpInst::ICMP_SLT, Op0, C);

  // icmp sge Op0, (C + 1) -> icmp sgt Op0, C
  if (C && NoOp1WrapProblem && Pred == CmpInst::ICMP_SGE && match(D, m_One()))
    return new ICmpInst(CmpInst::ICMP_SGT, Op0, C);

  // icmp slt Op0, (C + 1) -> icmp sle Op0, C
  if (C && NoOp1WrapProblem && Pred == CmpInst::ICMP_SLT && match(D, m_One()))
    return new ICmpInst(CmpInst::ICMP_SLE, Op0, C);

  // TODO: The subtraction-related identities shown below also hold, but
  // canonicalization from (X -nuw 1) to (X + -1) means that the combinations
  // wouldn't happen even if they were implemented.
  //
  // icmp ult (A - 1), Op1 -> icmp ule A, Op1
  // icmp uge (A - 1), Op1 -> icmp ugt A, Op1
  // icmp ugt Op0, (C - 1) -> icmp uge Op0, C
  // icmp ule Op0, (C - 1) -> icmp ult Op0, C

  // icmp ule (A + 1), Op1 -> icmp ult A, Op1
  if (A && NoOp0WrapProblem && Pred == CmpInst::ICMP_ULE && match(B, m_One()))
    return new ICmpInst(CmpInst::ICMP_ULT, A, Op1);

  // icmp ugt (A + 1), Op1 -> icmp uge A, Op1
  if (A && NoOp0WrapProblem && Pred == CmpInst::ICMP_UGT && match(B, m_One()))
    return new ICmpInst(CmpInst::ICMP_UGE, A, Op1);

  // icmp uge Op0, (C + 1) -> icmp ugt Op0, C
  if (C && NoOp1WrapProblem && Pred == CmpInst::ICMP_UGE && match(D, m_One()))
    return new ICmpInst(CmpInst::ICMP_UGT, Op0, C);

  // icmp ult Op0, (C + 1) -> icmp ule Op0, C
  if (C && NoOp1WrapProblem && Pred == CmpInst::ICMP_ULT && match(D, m_One()))
    return new ICmpInst(CmpInst::ICMP_ULE, Op0, C);

  // if C1 has greater magnitude than C2:
  //  icmp (A + C1), (C + C2) -> icmp (A + C3), C
  //  s.t. C3 = C1 - C2
  //
  // if C2 has greater magnitude than C1:
  //  icmp (A + C1), (C + C2) -> icmp A, (C + C3)
  //  s.t. C3 = C2 - C1
  if (A && C && NoOp0WrapProblem && NoOp1WrapProblem &&
      (BO0->hasOneUse() || BO1->hasOneUse()) && !I.isUnsigned()) {
    const APInt *AP1, *AP2;
    // TODO: Support non-uniform vectors.
    // TODO: Allow undef passthrough if B AND D's element is undef.
    if (match(B, m_APIntAllowUndef(AP1)) && match(D, m_APIntAllowUndef(AP2)) &&
        AP1->isNegative() == AP2->isNegative()) {
      APInt AP1Abs = AP1->abs();
      APInt AP2Abs = AP2->abs();
      if (AP1Abs.uge(AP2Abs)) {
        APInt Diff = *AP1 - *AP2;
        bool HasNUW = BO0->hasNoUnsignedWrap() && Diff.ule(*AP1);
        bool HasNSW = BO0->hasNoSignedWrap();
        Constant *C3 = Constant::getIntegerValue(BO0->getType(), Diff);
        Value *NewAdd = Builder.CreateAdd(A, C3, "", HasNUW, HasNSW);
        return new ICmpInst(Pred, NewAdd, C);
      } else {
        APInt Diff = *AP2 - *AP1;
        bool HasNUW = BO1->hasNoUnsignedWrap() && Diff.ule(*AP2);
        bool HasNSW = BO1->hasNoSignedWrap();
        Constant *C3 = Constant::getIntegerValue(BO0->getType(), Diff);
        Value *NewAdd = Builder.CreateAdd(C, C3, "", HasNUW, HasNSW);
        return new ICmpInst(Pred, A, NewAdd);
      }
    }
    Constant *Cst1, *Cst2;
    if (match(B, m_ImmConstant(Cst1)) && match(D, m_ImmConstant(Cst2)) &&
        ICmpInst::isEquality(Pred)) {
      Constant *Diff = ConstantExpr::getSub(Cst2, Cst1);
      Value *NewAdd = Builder.CreateAdd(C, Diff);
      return new ICmpInst(Pred, A, NewAdd);
    }
  }

  // Analyze the case when either Op0 or Op1 is a sub instruction.
  // Op0 = A - B (or A and B are null); Op1 = C - D (or C and D are null).
  A = nullptr;
  B = nullptr;
  C = nullptr;
  D = nullptr;
  if (BO0 && BO0->getOpcode() == Instruction::Sub) {
    A = BO0->getOperand(0);
    B = BO0->getOperand(1);
  }
  if (BO1 && BO1->getOpcode() == Instruction::Sub) {
    C = BO1->getOperand(0);
    D = BO1->getOperand(1);
  }

  // icmp (A-B), A -> icmp 0, B for equalities or if there is no overflow.
  if (A == Op1 && NoOp0WrapProblem)
    return new ICmpInst(Pred, Constant::getNullValue(Op1->getType()), B);
  // icmp C, (C-D) -> icmp D, 0 for equalities or if there is no overflow.
  if (C == Op0 && NoOp1WrapProblem)
    return new ICmpInst(Pred, D, Constant::getNullValue(Op0->getType()));

  // Convert sub-with-unsigned-overflow comparisons into a comparison of args.
  // (A - B) u>/u<= A --> B u>/u<= A
  if (A == Op1 && (Pred == ICmpInst::ICMP_UGT || Pred == ICmpInst::ICMP_ULE))
    return new ICmpInst(Pred, B, A);
  // C u</u>= (C - D) --> C u</u>= D
  if (C == Op0 && (Pred == ICmpInst::ICMP_ULT || Pred == ICmpInst::ICMP_UGE))
    return new ICmpInst(Pred, C, D);
  // (A - B) u>=/u< A --> B u>/u<= A  iff B != 0
  if (A == Op1 && (Pred == ICmpInst::ICMP_UGE || Pred == ICmpInst::ICMP_ULT) &&
      isKnownNonZero(B, Q.DL, /*Depth=*/0, Q.AC, Q.CxtI, Q.DT))
    return new ICmpInst(CmpInst::getFlippedStrictnessPredicate(Pred), B, A);
  // C u<=/u> (C - D) --> C u</u>= D  iff D != 0
  if (C == Op0 && (Pred == ICmpInst::ICMP_ULE || Pred == ICmpInst::ICMP_UGT) &&
      isKnownNonZero(D, Q.DL, /*Depth=*/0, Q.AC, Q.CxtI, Q.DT))
    return new ICmpInst(CmpInst::getFlippedStrictnessPredicate(Pred), C, D);

  // icmp (A-B), (C-B) -> icmp A, C for equalities or if there is no overflow.
  if (B && D && B == D && NoOp0WrapProblem && NoOp1WrapProblem)
    return new ICmpInst(Pred, A, C);

  // icmp (A-B), (A-D) -> icmp D, B for equalities or if there is no overflow.
  if (A && C && A == C && NoOp0WrapProblem && NoOp1WrapProblem)
    return new ICmpInst(Pred, D, B);

  // icmp (0-X) < cst --> x > -cst
  // (NoOp0WrapProblem can only be true when BO0 is non-null, so BO0 is safe
  // to match against here.)
  if (NoOp0WrapProblem && ICmpInst::isSigned(Pred)) {
    Value *X;
    if (match(BO0, m_Neg(m_Value(X))))
      if (Constant *RHSC = dyn_cast<Constant>(Op1))
        if (RHSC->isNotMinSignedValue())
          return new ICmpInst(I.getSwappedPredicate(), X,
                              ConstantExpr::getNeg(RHSC));
  }

  {
    // Try to remove shared constant multiplier from equality comparison:
    // X * C == Y * C (with no overflowing/aliasing) --> X == Y
    Value *X, *Y;
    const APInt *C;
    if (match(Op0, m_Mul(m_Value(X), m_APInt(C))) && *C != 0 &&
        match(Op1, m_Mul(m_Value(Y), m_SpecificInt(*C))) && I.isEquality())
      if (!C->countTrailingZeros() ||
          (BO0 && BO1 && BO0->hasNoSignedWrap() && BO1->hasNoSignedWrap()) ||
          (BO0 && BO1 && BO0->hasNoUnsignedWrap() && BO1->hasNoUnsignedWrap()))
        return new ICmpInst(Pred, X, Y);
  }

  BinaryOperator *SRem = nullptr;
  // icmp (srem X, Y), Y
  if (BO0 && BO0->getOpcode() == Instruction::SRem && Op1 == BO0->getOperand(1))
    SRem = BO0;
  // icmp Y, (srem X, Y)
  else if (BO1 && BO1->getOpcode() == Instruction::SRem &&
           Op0 == BO1->getOperand(1))
    SRem = BO1;
  if (SRem) {
    // We don't check hasOneUse to avoid increasing register pressure because
    // the value we use is the same value this instruction was already using.
    // The predicate is swapped when the srem is the LHS, so the cases below
    // are all written for the form "icmp Y, (srem X, Y)".
    switch (SRem == BO0 ? ICmpInst::getSwappedPredicate(Pred) : Pred) {
    default:
      break;
    case ICmpInst::ICMP_EQ:
      return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType()));
    case ICmpInst::ICMP_NE:
      return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType()));
    case ICmpInst::ICMP_SGT:
    case ICmpInst::ICMP_SGE:
      return new ICmpInst(ICmpInst::ICMP_SGT, SRem->getOperand(1),
                          Constant::getAllOnesValue(SRem->getType()));
    case ICmpInst::ICMP_SLT:
    case ICmpInst::ICMP_SLE:
      return new ICmpInst(ICmpInst::ICMP_SLT, SRem->getOperand(1),
                          Constant::getNullValue(SRem->getType()));
    }
  }

  // Both sides are one-use binary operators with the same opcode and the same
  // second operand: try to compare the first operands directly.
  if (BO0 && BO1 && BO0->getOpcode() == BO1->getOpcode() && BO0->hasOneUse() &&
      BO1->hasOneUse() && BO0->getOperand(1) == BO1->getOperand(1)) {
    switch (BO0->getOpcode()) {
    default:
      break;
    case Instruction::Add:
    case Instruction::Sub:
    case Instruction::Xor: {
      if (I.isEquality()) // a+x icmp eq/ne b+x --> a icmp b
        return new ICmpInst(Pred, BO0->getOperand(0), BO1->getOperand(0));

      const APInt *C;
      if (match(BO0->getOperand(1), m_APInt(C))) {
        // icmp u/s (a ^ signmask), (b ^ signmask) --> icmp s/u a, b
        if (C->isSignMask()) {
          ICmpInst::Predicate NewPred = I.getFlippedSignednessPredicate();
          return new ICmpInst(NewPred, BO0->getOperand(0), BO1->getOperand(0));
        }

        // icmp u/s (a ^ maxsignval), (b ^ maxsignval) --> icmp s/u' a, b
        if (BO0->getOpcode() == Instruction::Xor && C->isMaxSignedValue()) {
          ICmpInst::Predicate NewPred = I.getFlippedSignednessPredicate();
          NewPred = I.getSwappedPredicate(NewPred);
          return new ICmpInst(NewPred, BO0->getOperand(0), BO1->getOperand(0));
        }
      }
      break;
    }
    case Instruction::Mul: {
      if (!I.isEquality())
        break;

      const APInt *C;
      if (match(BO0->getOperand(1), m_APInt(C)) && !C->isZero() &&
          !C->isOne()) {
        // icmp eq/ne (X * C), (Y * C) --> icmp (X & Mask), (Y & Mask)
        // Mask = -1 >> count-trailing-zeros(C).
        if (unsigned TZs = C->countTrailingZeros()) {
          Constant *Mask = ConstantInt::get(
              BO0->getType(),
              APInt::getLowBitsSet(C->getBitWidth(), C->getBitWidth() - TZs));
          Value *And1 = Builder.CreateAnd(BO0->getOperand(0), Mask);
          Value *And2 = Builder.CreateAnd(BO1->getOperand(0), Mask);
          return new ICmpInst(Pred, And1, And2);
        }
      }
      break;
    }
    case Instruction::UDiv:
    case Instruction::LShr:
      if (I.isSigned() || !BO0->isExact() || !BO1->isExact())
        break;
      return new ICmpInst(Pred, BO0->getOperand(0), BO1->getOperand(0));

    case Instruction::SDiv:
      if (!I.isEquality() || !BO0->isExact() || !BO1->isExact())
        break;
      return new ICmpInst(Pred, BO0->getOperand(0), BO1->getOperand(0));

    case Instruction::AShr:
      if (!BO0->isExact() || !BO1->isExact())
        break;
      return new ICmpInst(Pred, BO0->getOperand(0), BO1->getOperand(0));

    case Instruction::Shl: {
      bool NUW = BO0->hasNoUnsignedWrap() && BO1->hasNoUnsignedWrap();
      bool NSW = BO0->hasNoSignedWrap() && BO1->hasNoSignedWrap();
      if (!NUW && !NSW)
        break;
      if (!NSW && I.isSigned())
        break;
      return new ICmpInst(Pred, BO0->getOperand(0), BO1->getOperand(0));
    }
    }
  }

  if (BO0) {
    // Transform  A & (L - 1) `ult` L --> L != 0
    auto LSubOne = m_Add(m_Specific(Op1), m_AllOnes());
    auto BitwiseAnd = m_c_And(m_Value(), LSubOne);

    if (match(BO0, BitwiseAnd) && Pred == ICmpInst::ICMP_ULT) {
      auto *Zero = Constant::getNullValue(BO0->getType());
      return new ICmpInst(ICmpInst::ICMP_NE, Op1, Zero);
    }
  }

  // The remaining folds return a replacement Value rather than a new
  // instruction, so the uses of I are rewritten here.
  if (Value *V = foldMultiplicationOverflowCheck(I))
    return replaceInstUsesWith(I, V);

  if (Value *V = foldICmpWithLowBitMaskedVal(I, Builder))
    return replaceInstUsesWith(I, V);

  if (Value *V = foldICmpWithTruncSignExtendedVal(I, Builder))
    return replaceInstUsesWith(I, V);

  if (Value *V = foldShiftIntoShiftInAnotherHandOfAndInICmp(I, SQ, Builder))
    return replaceInstUsesWith(I, V);

  return nullptr;
}

/// Fold icmp Pred min|max(X, Y), X.
+static Instruction *foldICmpWithMinMax(ICmpInst &Cmp) { + ICmpInst::Predicate Pred = Cmp.getPredicate(); + Value *Op0 = Cmp.getOperand(0); + Value *X = Cmp.getOperand(1); + + // Canonicalize minimum or maximum operand to LHS of the icmp. + if (match(X, m_c_SMin(m_Specific(Op0), m_Value())) || + match(X, m_c_SMax(m_Specific(Op0), m_Value())) || + match(X, m_c_UMin(m_Specific(Op0), m_Value())) || + match(X, m_c_UMax(m_Specific(Op0), m_Value()))) { + std::swap(Op0, X); + Pred = Cmp.getSwappedPredicate(); + } + + Value *Y; + if (match(Op0, m_c_SMin(m_Specific(X), m_Value(Y)))) { + // smin(X, Y) == X --> X s<= Y + // smin(X, Y) s>= X --> X s<= Y + if (Pred == CmpInst::ICMP_EQ || Pred == CmpInst::ICMP_SGE) + return new ICmpInst(ICmpInst::ICMP_SLE, X, Y); + + // smin(X, Y) != X --> X s> Y + // smin(X, Y) s< X --> X s> Y + if (Pred == CmpInst::ICMP_NE || Pred == CmpInst::ICMP_SLT) + return new ICmpInst(ICmpInst::ICMP_SGT, X, Y); + + // These cases should be handled in InstSimplify: + // smin(X, Y) s<= X --> true + // smin(X, Y) s> X --> false + return nullptr; + } + + if (match(Op0, m_c_SMax(m_Specific(X), m_Value(Y)))) { + // smax(X, Y) == X --> X s>= Y + // smax(X, Y) s<= X --> X s>= Y + if (Pred == CmpInst::ICMP_EQ || Pred == CmpInst::ICMP_SLE) + return new ICmpInst(ICmpInst::ICMP_SGE, X, Y); + + // smax(X, Y) != X --> X s< Y + // smax(X, Y) s> X --> X s< Y + if (Pred == CmpInst::ICMP_NE || Pred == CmpInst::ICMP_SGT) + return new ICmpInst(ICmpInst::ICMP_SLT, X, Y); + + // These cases should be handled in InstSimplify: + // smax(X, Y) s>= X --> true + // smax(X, Y) s< X --> false + return nullptr; + } + + if (match(Op0, m_c_UMin(m_Specific(X), m_Value(Y)))) { + // umin(X, Y) == X --> X u<= Y + // umin(X, Y) u>= X --> X u<= Y + if (Pred == CmpInst::ICMP_EQ || Pred == CmpInst::ICMP_UGE) + return new ICmpInst(ICmpInst::ICMP_ULE, X, Y); + + // umin(X, Y) != X --> X u> Y + // umin(X, Y) u< X --> X u> Y + if (Pred == CmpInst::ICMP_NE || Pred == CmpInst::ICMP_ULT) + return new 
ICmpInst(ICmpInst::ICMP_UGT, X, Y); + + // These cases should be handled in InstSimplify: + // umin(X, Y) u<= X --> true + // umin(X, Y) u> X --> false + return nullptr; + } + + if (match(Op0, m_c_UMax(m_Specific(X), m_Value(Y)))) { + // umax(X, Y) == X --> X u>= Y + // umax(X, Y) u<= X --> X u>= Y + if (Pred == CmpInst::ICMP_EQ || Pred == CmpInst::ICMP_ULE) + return new ICmpInst(ICmpInst::ICMP_UGE, X, Y); + + // umax(X, Y) != X --> X u< Y + // umax(X, Y) u> X --> X u< Y + if (Pred == CmpInst::ICMP_NE || Pred == CmpInst::ICMP_UGT) + return new ICmpInst(ICmpInst::ICMP_ULT, X, Y); + + // These cases should be handled in InstSimplify: + // umax(X, Y) u>= X --> true + // umax(X, Y) u< X --> false + return nullptr; + } + + return nullptr; +} + +Instruction *InstCombinerImpl::foldICmpEquality(ICmpInst &I) { + if (!I.isEquality()) + return nullptr; + + Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); + const CmpInst::Predicate Pred = I.getPredicate(); + Value *A, *B, *C, *D; + if (match(Op0, m_Xor(m_Value(A), m_Value(B)))) { + if (A == Op1 || B == Op1) { // (A^B) == A -> B == 0 + Value *OtherVal = A == Op1 ? B : A; + return new ICmpInst(Pred, OtherVal, Constant::getNullValue(A->getType())); + } + + if (match(Op1, m_Xor(m_Value(C), m_Value(D)))) { + // A^c1 == C^c2 --> A == C^(c1^c2) + ConstantInt *C1, *C2; + if (match(B, m_ConstantInt(C1)) && match(D, m_ConstantInt(C2)) && + Op1->hasOneUse()) { + Constant *NC = Builder.getInt(C1->getValue() ^ C2->getValue()); + Value *Xor = Builder.CreateXor(C, NC); + return new ICmpInst(Pred, A, Xor); + } + + // A^B == A^D -> B == D + if (A == C) + return new ICmpInst(Pred, B, D); + if (A == D) + return new ICmpInst(Pred, B, C); + if (B == C) + return new ICmpInst(Pred, A, D); + if (B == D) + return new ICmpInst(Pred, A, C); + } + } + + if (match(Op1, m_Xor(m_Value(A), m_Value(B))) && (A == Op0 || B == Op0)) { + // A == (A^B) -> B == 0 + Value *OtherVal = A == Op0 ? 
B : A; + return new ICmpInst(Pred, OtherVal, Constant::getNullValue(A->getType())); + } + + // (X&Z) == (Y&Z) -> (X^Y) & Z == 0 + if (match(Op0, m_OneUse(m_And(m_Value(A), m_Value(B)))) && + match(Op1, m_OneUse(m_And(m_Value(C), m_Value(D))))) { + Value *X = nullptr, *Y = nullptr, *Z = nullptr; + + if (A == C) { + X = B; + Y = D; + Z = A; + } else if (A == D) { + X = B; + Y = C; + Z = A; + } else if (B == C) { + X = A; + Y = D; + Z = B; + } else if (B == D) { + X = A; + Y = C; + Z = B; + } + + if (X) { // Build (X^Y) & Z + Op1 = Builder.CreateXor(X, Y); + Op1 = Builder.CreateAnd(Op1, Z); + return new ICmpInst(Pred, Op1, Constant::getNullValue(Op1->getType())); + } + } + + { + // Similar to above, but specialized for constant because invert is needed: + // (X | C) == (Y | C) --> (X ^ Y) & ~C == 0 + Value *X, *Y; + Constant *C; + if (match(Op0, m_OneUse(m_Or(m_Value(X), m_Constant(C)))) && + match(Op1, m_OneUse(m_Or(m_Value(Y), m_Specific(C))))) { + Value *Xor = Builder.CreateXor(X, Y); + Value *And = Builder.CreateAnd(Xor, ConstantExpr::getNot(C)); + return new ICmpInst(Pred, And, Constant::getNullValue(And->getType())); + } + } + + // Transform (zext A) == (B & (1<<X)-1) --> A == (trunc B) + // and (B & (1<<X)-1) == (zext A) --> A == (trunc B) + ConstantInt *Cst1; + if ((Op0->hasOneUse() && match(Op0, m_ZExt(m_Value(A))) && + match(Op1, m_And(m_Value(B), m_ConstantInt(Cst1)))) || + (Op1->hasOneUse() && match(Op0, m_And(m_Value(B), m_ConstantInt(Cst1))) && + match(Op1, m_ZExt(m_Value(A))))) { + APInt Pow2 = Cst1->getValue() + 1; + if (Pow2.isPowerOf2() && isa<IntegerType>(A->getType()) && + Pow2.logBase2() == cast<IntegerType>(A->getType())->getBitWidth()) + return new ICmpInst(Pred, A, Builder.CreateTrunc(B, A->getType())); + } + + // (A >> C) == (B >> C) --> (A^B) u< (1 << C) + // For lshr and ashr pairs. 
+ const APInt *AP1, *AP2; + if ((match(Op0, m_OneUse(m_LShr(m_Value(A), m_APIntAllowUndef(AP1)))) && + match(Op1, m_OneUse(m_LShr(m_Value(B), m_APIntAllowUndef(AP2))))) || + (match(Op0, m_OneUse(m_AShr(m_Value(A), m_APIntAllowUndef(AP1)))) && + match(Op1, m_OneUse(m_AShr(m_Value(B), m_APIntAllowUndef(AP2)))))) { + if (AP1 != AP2) + return nullptr; + unsigned TypeBits = AP1->getBitWidth(); + unsigned ShAmt = AP1->getLimitedValue(TypeBits); + if (ShAmt < TypeBits && ShAmt != 0) { + ICmpInst::Predicate NewPred = + Pred == ICmpInst::ICMP_NE ? ICmpInst::ICMP_UGE : ICmpInst::ICMP_ULT; + Value *Xor = Builder.CreateXor(A, B, I.getName() + ".unshifted"); + APInt CmpVal = APInt::getOneBitSet(TypeBits, ShAmt); + return new ICmpInst(NewPred, Xor, ConstantInt::get(A->getType(), CmpVal)); + } + } + + // (A << C) == (B << C) --> ((A^B) & (~0U >> C)) == 0 + if (match(Op0, m_OneUse(m_Shl(m_Value(A), m_ConstantInt(Cst1)))) && + match(Op1, m_OneUse(m_Shl(m_Value(B), m_Specific(Cst1))))) { + unsigned TypeBits = Cst1->getBitWidth(); + unsigned ShAmt = (unsigned)Cst1->getLimitedValue(TypeBits); + if (ShAmt < TypeBits && ShAmt != 0) { + Value *Xor = Builder.CreateXor(A, B, I.getName() + ".unshifted"); + APInt AndVal = APInt::getLowBitsSet(TypeBits, TypeBits - ShAmt); + Value *And = Builder.CreateAnd(Xor, Builder.getInt(AndVal), + I.getName() + ".mask"); + return new ICmpInst(Pred, And, Constant::getNullValue(Cst1->getType())); + } + } + + // Transform "icmp eq (trunc (lshr(X, cst1)), cst" to + // "icmp (and X, mask), cst" + uint64_t ShAmt = 0; + if (Op0->hasOneUse() && + match(Op0, m_Trunc(m_OneUse(m_LShr(m_Value(A), m_ConstantInt(ShAmt))))) && + match(Op1, m_ConstantInt(Cst1)) && + // Only do this when A has multiple uses. This is most important to do + // when it exposes other optimizations. 
+ !A->hasOneUse()) { + unsigned ASize = cast<IntegerType>(A->getType())->getPrimitiveSizeInBits(); + + if (ShAmt < ASize) { + APInt MaskV = + APInt::getLowBitsSet(ASize, Op0->getType()->getPrimitiveSizeInBits()); + MaskV <<= ShAmt; + + APInt CmpV = Cst1->getValue().zext(ASize); + CmpV <<= ShAmt; + + Value *Mask = Builder.CreateAnd(A, Builder.getInt(MaskV)); + return new ICmpInst(Pred, Mask, Builder.getInt(CmpV)); + } + } + + if (Instruction *ICmp = foldICmpIntrinsicWithIntrinsic(I)) + return ICmp; + + // Canonicalize checking for a power-of-2-or-zero value: + // (A & (A-1)) == 0 --> ctpop(A) < 2 (two commuted variants) + // ((A-1) & A) != 0 --> ctpop(A) > 1 (two commuted variants) + if (!match(Op0, m_OneUse(m_c_And(m_Add(m_Value(A), m_AllOnes()), + m_Deferred(A)))) || + !match(Op1, m_ZeroInt())) + A = nullptr; + + // (A & -A) == A --> ctpop(A) < 2 (four commuted variants) + // (-A & A) != A --> ctpop(A) > 1 (four commuted variants) + if (match(Op0, m_OneUse(m_c_And(m_Neg(m_Specific(Op1)), m_Specific(Op1))))) + A = Op1; + else if (match(Op1, + m_OneUse(m_c_And(m_Neg(m_Specific(Op0)), m_Specific(Op0))))) + A = Op0; + + if (A) { + Type *Ty = A->getType(); + CallInst *CtPop = Builder.CreateUnaryIntrinsic(Intrinsic::ctpop, A); + return Pred == ICmpInst::ICMP_EQ + ? new ICmpInst(ICmpInst::ICMP_ULT, CtPop, ConstantInt::get(Ty, 2)) + : new ICmpInst(ICmpInst::ICMP_UGT, CtPop, ConstantInt::get(Ty, 1)); + } + + // Match icmp eq (trunc (lshr A, BW), (ashr (trunc A), BW-1)), which checks the + // top BW/2 + 1 bits are all the same. Create "A >=s INT_MIN && A <=s INT_MAX", + // which we generate as "icmp ult (add A, 2^(BW-1)), 2^BW" to skip a few steps + // of instcombine. 
  unsigned BitWidth = Op0->getType()->getScalarSizeInBits();
  if (match(Op0, m_AShr(m_Trunc(m_Value(A)), m_SpecificInt(BitWidth - 1))) &&
      match(Op1, m_Trunc(m_LShr(m_Specific(A), m_SpecificInt(BitWidth)))) &&
      A->getType()->getScalarSizeInBits() == BitWidth * 2 &&
      (I.getOperand(0)->hasOneUse() || I.getOperand(1)->hasOneUse())) {
    APInt C = APInt::getOneBitSet(BitWidth * 2, BitWidth - 1);
    Value *Add = Builder.CreateAdd(A, ConstantInt::get(A->getType(), C));
    return new ICmpInst(Pred == ICmpInst::ICMP_EQ ? ICmpInst::ICMP_ULT
                                                  : ICmpInst::ICMP_UGE,
                        Add, ConstantInt::get(A->getType(), C.shl(1)));
  }

  return nullptr;
}

/// Fold an icmp of a single-use trunc against a constant matched by m_APInt
/// into a mask-and-compare on the un-truncated source value.
///
/// \param ICmp    Compare instruction to inspect.
/// \param Builder Builder used to create the masking 'and'.
/// \returns The replacement compare, or nullptr if no pattern applies.
static Instruction *foldICmpWithTrunc(ICmpInst &ICmp,
                                      InstCombiner::BuilderTy &Builder) {
  ICmpInst::Predicate Pred = ICmp.getPredicate();
  Value *Op0 = ICmp.getOperand(0), *Op1 = ICmp.getOperand(1);

  // Try to canonicalize trunc + compare-to-constant into a mask + cmp.
  // The trunc masks high bits while the compare may effectively mask low bits.
  Value *X;
  const APInt *C;
  if (!match(Op0, m_OneUse(m_Trunc(m_Value(X)))) || !match(Op1, m_APInt(C)))
    return nullptr;

  // This matches patterns corresponding to tests of the signbit as well as:
  // (trunc X) u< C --> (X & -C) == 0 (are all masked-high-bits clear?)
  // (trunc X) u> C --> (X & ~C) != 0 (are any masked-high-bits set?)
  // Note: Pred, X and Mask are passed by name and used afterwards, so the
  // helper is expected to update them on success.
  APInt Mask;
  if (decomposeBitTestICmp(Op0, Op1, Pred, X, Mask, true /* WithTrunc */)) {
    Value *And = Builder.CreateAnd(X, Mask);
    Constant *Zero = ConstantInt::getNullValue(X->getType());
    return new ICmpInst(Pred, And, Zero);
  }

  // The remaining folds build their mask in the width of the trunc source.
  unsigned SrcBits = X->getType()->getScalarSizeInBits();
  if (Pred == ICmpInst::ICMP_ULT && C->isNegatedPowerOf2()) {
    // If C is a negative power-of-2 (high-bit mask):
    // (trunc X) u< C --> (X & C) != C (are any masked-high-bits clear?)
    Constant *MaskC = ConstantInt::get(X->getType(), C->zext(SrcBits));
    Value *And = Builder.CreateAnd(X, MaskC);
    return new ICmpInst(ICmpInst::ICMP_NE, And, MaskC);
  }

  if (Pred == ICmpInst::ICMP_UGT && (~*C).isPowerOf2()) {
    // If C is not-of-power-of-2 (one clear bit):
    // (trunc X) u> C --> (X & (C+1)) == C+1 (are all masked-high-bits set?)
    Constant *MaskC = ConstantInt::get(X->getType(), (*C + 1).zext(SrcBits));
    Value *And = Builder.CreateAnd(X, MaskC);
    return new ICmpInst(ICmpInst::ICMP_EQ, And, MaskC);
  }

  return nullptr;
}

/// Fold an icmp whose operand 0 is a zext or sext.
///
/// Two shapes are handled: a compare against another zext/sext (the narrow
/// sources are compared directly, extending the narrower one if the source
/// types differ), and a compare against a constant (the constant is truncated
/// to the source type when it survives a trunc + re-extend round trip).
///
/// Operand 0 must be a CastInst (asserted).
/// \returns The replacement compare, or nullptr if no fold applies.
Instruction *InstCombinerImpl::foldICmpWithZextOrSext(ICmpInst &ICmp) {
  assert(isa<CastInst>(ICmp.getOperand(0)) && "Expected cast for operand 0");
  auto *CastOp0 = cast<CastInst>(ICmp.getOperand(0));
  Value *X;
  if (!match(CastOp0, m_ZExtOrSExt(m_Value(X))))
    return nullptr;

  bool IsSignedExt = CastOp0->getOpcode() == Instruction::SExt;
  bool IsSignedCmp = ICmp.isSigned();

  // icmp Pred (ext X), (ext Y)
  Value *Y;
  if (match(ICmp.getOperand(1), m_ZExtOrSExt(m_Value(Y)))) {
    bool IsZext0 = isa<ZExtOperator>(ICmp.getOperand(0));
    bool IsZext1 = isa<ZExtOperator>(ICmp.getOperand(1));

    // If we have mismatched casts, treat the zext of a non-negative source as
    // a sext to simulate matching casts. Otherwise, we are done.
    // TODO: Can we handle some predicates (equality) without non-negative?
    if (IsZext0 != IsZext1) {
      if ((IsZext0 && isKnownNonNegative(X, DL, 0, &AC, &ICmp, &DT)) ||
          (IsZext1 && isKnownNonNegative(Y, DL, 0, &AC, &ICmp, &DT)))
        IsSignedExt = true;
      else
        return nullptr;
    }

    // Not an extension from the same type?
    Type *XTy = X->getType(), *YTy = Y->getType();
    if (XTy != YTy) {
      // One of the casts must have one use because we are creating a new cast.
      if (!ICmp.getOperand(0)->hasOneUse() && !ICmp.getOperand(1)->hasOneUse())
        return nullptr;
      // Extend the narrower operand to the type of the wider operand.
      CastInst::CastOps CastOpcode =
          IsSignedExt ? Instruction::SExt : Instruction::ZExt;
      if (XTy->getScalarSizeInBits() < YTy->getScalarSizeInBits())
        X = Builder.CreateCast(CastOpcode, X, YTy);
      else if (YTy->getScalarSizeInBits() < XTy->getScalarSizeInBits())
        Y = Builder.CreateCast(CastOpcode, Y, XTy);
      else
        return nullptr;
    }

    // (zext X) == (zext Y) --> X == Y
    // (sext X) == (sext Y) --> X == Y
    if (ICmp.isEquality())
      return new ICmpInst(ICmp.getPredicate(), X, Y);

    // A signed comparison of sign extended values simplifies into a
    // signed comparison.
    if (IsSignedCmp && IsSignedExt)
      return new ICmpInst(ICmp.getPredicate(), X, Y);

    // The other three cases all fold into an unsigned comparison.
    return new ICmpInst(ICmp.getUnsignedPredicate(), X, Y);
  }

  // Below here, we are only folding a compare with constant.
  auto *C = dyn_cast<Constant>(ICmp.getOperand(1));
  if (!C)
    return nullptr;

  // Compute the constant that would happen if we truncated to SrcTy then
  // re-extended to DestTy.
  Type *SrcTy = CastOp0->getSrcTy();
  Type *DestTy = CastOp0->getDestTy();
  Constant *Res1 = ConstantExpr::getTrunc(C, SrcTy);
  Constant *Res2 = ConstantExpr::getCast(CastOp0->getOpcode(), Res1, DestTy);

  // If the re-extended constant didn't change...
  if (Res2 == C) {
    if (ICmp.isEquality())
      return new ICmpInst(ICmp.getPredicate(), X, Res1);

    // A signed comparison of sign extended values simplifies into a
    // signed comparison.
    if (IsSignedExt && IsSignedCmp)
      return new ICmpInst(ICmp.getPredicate(), X, Res1);

    // The other three cases all fold into an unsigned comparison.
    return new ICmpInst(ICmp.getUnsignedPredicate(), X, Res1);
  }

  // The re-extended constant changed, partly changed (in the case of a vector),
  // or could not be determined to be equal (in the case of a constant
  // expression), so the constant cannot be represented in the shorter type.
  // All the cases that fold to true or false will have already been handled
  // by simplifyICmpInst, so only deal with the tricky case.
  if (IsSignedCmp || !IsSignedExt || !isa<ConstantInt>(C))
    return nullptr;

  // Is source op positive?
  // icmp ult (sext X), C --> icmp sgt X, -1
  if (ICmp.getPredicate() == ICmpInst::ICMP_ULT)
    return new ICmpInst(CmpInst::ICMP_SGT, X, Constant::getAllOnesValue(SrcTy));

  // Is source op negative?
  // icmp ugt (sext X), C --> icmp slt X, 0
  assert(ICmp.getPredicate() == ICmpInst::ICMP_UGT && "ICmp should be folded!");
  return new ICmpInst(CmpInst::ICMP_SLT, X, Constant::getNullValue(SrcTy));
}

/// Handle icmp (cast x), (cast or constant).
///
/// Tries, in order: stripping inttoptr round-trip casts from either operand,
/// folding a ptrtoint compare into a compare of the pointers, the trunc folds
/// (foldICmpWithTrunc) and finally the zext/sext folds
/// (foldICmpWithZextOrSext).
/// \returns The replacement compare, or nullptr if nothing applies.
Instruction *InstCombinerImpl::foldICmpWithCastOp(ICmpInst &ICmp) {
  // If any operand of ICmp is a inttoptr roundtrip cast then remove it as
  // icmp compares only pointer's value.
  // icmp (inttoptr (ptrtoint p1)), p2 --> icmp p1, p2.
  Value *SimplifiedOp0 = simplifyIntToPtrRoundTripCast(ICmp.getOperand(0));
  Value *SimplifiedOp1 = simplifyIntToPtrRoundTripCast(ICmp.getOperand(1));
  if (SimplifiedOp0 || SimplifiedOp1)
    return new ICmpInst(ICmp.getPredicate(),
                        SimplifiedOp0 ? SimplifiedOp0 : ICmp.getOperand(0),
                        SimplifiedOp1 ? SimplifiedOp1 : ICmp.getOperand(1));

  // Everything below requires operand 0 to be a cast instruction and operand 1
  // to be either a constant or another cast instruction.
  auto *CastOp0 = dyn_cast<CastInst>(ICmp.getOperand(0));
  if (!CastOp0)
    return nullptr;
  if (!isa<Constant>(ICmp.getOperand(1)) && !isa<CastInst>(ICmp.getOperand(1)))
    return nullptr;

  Value *Op0Src = CastOp0->getOperand(0);
  Type *SrcTy = CastOp0->getSrcTy();
  Type *DestTy = CastOp0->getDestTy();

  // Turn icmp (ptrtoint x), (ptrtoint/c) into a compare of the input if the
  // integer type is the same size as the pointer type.
  // For vector types the element types are compared.
  auto CompatibleSizes = [&](Type *SrcTy, Type *DestTy) {
    if (isa<VectorType>(SrcTy)) {
      SrcTy = cast<VectorType>(SrcTy)->getElementType();
      DestTy = cast<VectorType>(DestTy)->getElementType();
    }
    return DL.getPointerTypeSizeInBits(SrcTy) == DestTy->getIntegerBitWidth();
  };
  if (CastOp0->getOpcode() == Instruction::PtrToInt &&
      CompatibleSizes(SrcTy, DestTy)) {
    Value *NewOp1 = nullptr;
    if (auto *PtrToIntOp1 = dyn_cast<PtrToIntOperator>(ICmp.getOperand(1))) {
      // Only look through the second ptrtoint when both pointers live in the
      // same address space.
      Value *PtrSrc = PtrToIntOp1->getOperand(0);
      if (PtrSrc->getType()->getPointerAddressSpace() ==
          Op0Src->getType()->getPointerAddressSpace()) {
        NewOp1 = PtrToIntOp1->getOperand(0);
        // If the pointer types don't match, insert a bitcast.
        if (Op0Src->getType() != NewOp1->getType())
          NewOp1 = Builder.CreateBitCast(NewOp1, Op0Src->getType());
      }
    } else if (auto *RHSC = dyn_cast<Constant>(ICmp.getOperand(1))) {
      NewOp1 = ConstantExpr::getIntToPtr(RHSC, SrcTy);
    }

    if (NewOp1)
      return new ICmpInst(ICmp.getPredicate(), Op0Src, NewOp1);
  }

  if (Instruction *R = foldICmpWithTrunc(ICmp, Builder))
    return R;

  return foldICmpWithZextOrSext(ICmp);
}

/// Return true if \p RHS is the identity right-hand operand for \p BinaryOp:
/// zero for add and sub, one for mul. Any other opcode is unreachable.
static bool isNeutralValue(Instruction::BinaryOps BinaryOp, Value *RHS) {
  switch (BinaryOp) {
    default:
      llvm_unreachable("Unsupported binary op");
    case Instruction::Add:
    case Instruction::Sub:
      return match(RHS, m_Zero());
    case Instruction::Mul:
      return match(RHS, m_One());
  }
}

/// Dispatch to the signed or unsigned computeOverflowFor{Add,Sub,Mul} helper
/// selected by \p BinaryOp and \p IsSigned. Opcodes other than add, sub and
/// mul are unreachable.
OverflowResult
InstCombinerImpl::computeOverflow(Instruction::BinaryOps BinaryOp,
                                  bool IsSigned, Value *LHS, Value *RHS,
                                  Instruction *CxtI) const {
  switch (BinaryOp) {
    default:
      llvm_unreachable("Unsupported binary op");
    case Instruction::Add:
      if (IsSigned)
        return computeOverflowForSignedAdd(LHS, RHS, CxtI);
      else
        return computeOverflowForUnsignedAdd(LHS, RHS, CxtI);
    case Instruction::Sub:
      if (IsSigned)
        return computeOverflowForSignedSub(LHS, RHS, CxtI);
      else
        return computeOverflowForUnsignedSub(LHS, RHS, CxtI);
    case Instruction::Mul:
      if (IsSigned)
        return computeOverflowForSignedMul(LHS, RHS, CxtI);
      else
        return computeOverflowForUnsignedMul(LHS, RHS, CxtI);
  }
}

/// Try to decide statically whether an overflow-checked add/sub/mul overflows.
///
/// \param BinaryOp Arithmetic opcode being checked.
/// \param IsSigned Whether signed or unsigned overflow is being tested.
/// \param LHS,RHS  Operands; swapped locally when \p OrigI is commutative and
///                 only \p LHS is a constant.
/// \param OrigI    Original instruction; new instructions are inserted before
///                 it and the result takes its name.
/// \param Result   [out] Value of the arithmetic operation.
/// \param Overflow [out] Constant i1 (or vector of i1 matching \p LHS)
///                 overflow flag.
/// \returns true if \p Result and \p Overflow were set, false if overflow
///          could not be decided (OverflowResult::MayOverflow).
bool InstCombinerImpl::OptimizeOverflowCheck(Instruction::BinaryOps BinaryOp,
                                             bool IsSigned, Value *LHS,
                                             Value *RHS, Instruction &OrigI,
                                             Value *&Result,
                                             Constant *&Overflow) {
  if (OrigI.isCommutative() && isa<Constant>(LHS) && !isa<Constant>(RHS))
    std::swap(LHS, RHS);

  // If the overflow check was an add followed by a compare, the insertion point
  // may be pointing to the compare. We want to insert the new instructions
  // before the add in case there are uses of the add between the add and the
  // compare.
  Builder.SetInsertPoint(&OrigI);

  Type *OverflowTy = Type::getInt1Ty(LHS->getContext());
  if (auto *LHSTy = dyn_cast<VectorType>(LHS->getType()))
    OverflowTy = VectorType::get(OverflowTy, LHSTy->getElementCount());

  // x op identity never overflows and is just x.
  if (isNeutralValue(BinaryOp, RHS)) {
    Result = LHS;
    Overflow = ConstantInt::getFalse(OverflowTy);
    return true;
  }

  switch (computeOverflow(BinaryOp, IsSigned, LHS, RHS, &OrigI)) {
    case OverflowResult::MayOverflow:
      return false;
    case OverflowResult::AlwaysOverflowsLow:
    case OverflowResult::AlwaysOverflowsHigh:
      Result = Builder.CreateBinOp(BinaryOp, LHS, RHS);
      Result->takeName(&OrigI);
      Overflow = ConstantInt::getTrue(OverflowTy);
      return true;
    case OverflowResult::NeverOverflows:
      Result = Builder.CreateBinOp(BinaryOp, LHS, RHS);
      Result->takeName(&OrigI);
      Overflow = ConstantInt::getFalse(OverflowTy);
      // Record the proven no-wrap fact on the new instruction (the builder may
      // have produced a non-instruction value, hence the dyn_cast).
      if (auto *Inst = dyn_cast<Instruction>(Result)) {
        if (IsSigned)
          Inst->setHasNoSignedWrap();
        else
          Inst->setHasNoUnsignedWrap();
      }
      return true;
  }

  llvm_unreachable("Unexpected overflow result");
}

/// Recognize and process idiom involving test for multiplication
/// overflow.
///
/// The caller has matched a pattern of the form:
///   I = cmp u (mul(zext A, zext B), V
/// The function checks if this is a test for overflow and if so replaces
/// multiplication with call to 'mul.with.overflow' intrinsic.
///
/// \param I Compare instruction.
/// \param MulVal Result of 'mul' instruction. It is one of the arguments of
///               the compare instruction. Must be of integer type.
/// \param OtherVal The other argument of compare instruction.
/// \param IC Combiner whose builder and worklist are used for the rewrite.
/// \returns Instruction which must replace the compare instruction, NULL if no
///          replacement required.
static Instruction *processUMulZExtIdiom(ICmpInst &I, Value *MulVal,
                                         Value *OtherVal,
                                         InstCombinerImpl &IC) {
  // Don't bother doing this transformation for pointers, don't do it for
  // vectors.
  if (!isa<IntegerType>(MulVal->getType()))
    return nullptr;

  assert(I.getOperand(0) == MulVal || I.getOperand(1) == MulVal);
  assert(I.getOperand(0) == OtherVal || I.getOperand(1) == OtherVal);
  auto *MulInstr = dyn_cast<Instruction>(MulVal);
  if (!MulInstr)
    return nullptr;
  assert(MulInstr->getOpcode() == Instruction::Mul);

  auto *LHS = cast<ZExtOperator>(MulInstr->getOperand(0)),
       *RHS = cast<ZExtOperator>(MulInstr->getOperand(1));
  assert(LHS->getOpcode() == Instruction::ZExt);
  assert(RHS->getOpcode() == Instruction::ZExt);
  Value *A = LHS->getOperand(0), *B = RHS->getOperand(0);

  // Calculate type and width of the result produced by mul.with.overflow.
  // The wider of the two zext source types is used.
  Type *TyA = A->getType(), *TyB = B->getType();
  unsigned WidthA = TyA->getPrimitiveSizeInBits(),
           WidthB = TyB->getPrimitiveSizeInBits();
  unsigned MulWidth;
  Type *MulType;
  if (WidthB > WidthA) {
    MulWidth = WidthB;
    MulType = TyB;
  } else {
    MulWidth = WidthA;
    MulType = TyA;
  }

  // In order to replace the original mul with a narrower mul.with.overflow,
  // all uses must ignore upper bits of the product. The number of used low
  // bits must be not greater than the width of mul.with.overflow.
  if (MulVal->hasNUsesOrMore(2))
    for (User *U : MulVal->users()) {
      if (U == &I)
        continue;
      if (TruncInst *TI = dyn_cast<TruncInst>(U)) {
        // Check if truncation ignores bits above MulWidth.
        unsigned TruncWidth = TI->getType()->getPrimitiveSizeInBits();
        if (TruncWidth > MulWidth)
          return nullptr;
      } else if (BinaryOperator *BO = dyn_cast<BinaryOperator>(U)) {
        // Check if AND ignores bits above MulWidth.
        if (BO->getOpcode() != Instruction::And)
          return nullptr;
        if (ConstantInt *CI = dyn_cast<ConstantInt>(BO->getOperand(1))) {
          const APInt &CVal = CI->getValue();
          if (CVal.getBitWidth() - CVal.countLeadingZeros() > MulWidth)
            return nullptr;
        } else {
          // In this case we could have the operand of the binary operation
          // being defined in another block, and performing the replacement
          // could break the dominance relation.
          return nullptr;
        }
      } else {
        // Other uses prohibit this transformation.
        return nullptr;
      }
    }

  // Recognize patterns
  // Each case either breaks out of the switch (pattern recognized) or returns
  // nullptr.
  switch (I.getPredicate()) {
  case ICmpInst::ICMP_EQ:
  case ICmpInst::ICMP_NE:
    // Recognize pattern:
    //   mulval = mul(zext A, zext B)
    //   cmp eq/neq mulval, and(mulval, mask), mask selects low MulWidth bits.
    ConstantInt *CI;
    Value *ValToMask;
    if (match(OtherVal, m_And(m_Value(ValToMask), m_ConstantInt(CI)))) {
      if (ValToMask != MulVal)
        return nullptr;
      // mask + 1 is a power of two exactly when mask is a run of low ones.
      const APInt &CVal = CI->getValue() + 1;
      if (CVal.isPowerOf2()) {
        unsigned MaskWidth = CVal.logBase2();
        if (MaskWidth == MulWidth)
          break; // Recognized
      }
    }
    return nullptr;

  case ICmpInst::ICMP_UGT:
    // Recognize pattern:
    //   mulval = mul(zext A, zext B)
    //   cmp ugt mulval, max
    if (ConstantInt *CI = dyn_cast<ConstantInt>(OtherVal)) {
      APInt MaxVal = APInt::getMaxValue(MulWidth);
      MaxVal = MaxVal.zext(CI->getBitWidth());
      if (MaxVal.eq(CI->getValue()))
        break; // Recognized
    }
    return nullptr;

  case ICmpInst::ICMP_UGE:
    // Recognize pattern:
    //   mulval = mul(zext A, zext B)
    //   cmp uge mulval, max+1
    if (ConstantInt *CI = dyn_cast<ConstantInt>(OtherVal)) {
      APInt MaxVal = APInt::getOneBitSet(CI->getBitWidth(), MulWidth);
      if (MaxVal.eq(CI->getValue()))
        break; // Recognized
    }
    return nullptr;

  case ICmpInst::ICMP_ULE:
    // Recognize pattern:
    //   mulval = mul(zext A, zext B)
    //   cmp ule mulval, max
    if (ConstantInt *CI = dyn_cast<ConstantInt>(OtherVal)) {
      APInt MaxVal = APInt::getMaxValue(MulWidth);
      MaxVal = MaxVal.zext(CI->getBitWidth());
      if (MaxVal.eq(CI->getValue()))
        break; // Recognized
    }
    return nullptr;

  case ICmpInst::ICMP_ULT:
    // Recognize pattern:
    //   mulval = mul(zext A, zext B)
    //   cmp ult mulval, max + 1
    if (ConstantInt *CI = dyn_cast<ConstantInt>(OtherVal)) {
      APInt MaxVal = APInt::getOneBitSet(CI->getBitWidth(), MulWidth);
      if (MaxVal.eq(CI->getValue()))
        break; // Recognized
    }
    return nullptr;

  default:
    return nullptr;
  }

  InstCombiner::BuilderTy &Builder = IC.Builder;
  Builder.SetInsertPoint(MulInstr);

  // Replace: mul(zext A, zext B) --> mul.with.overflow(A, B)
  // The narrower source, if any, is zero-extended to MulType first.
  Value *MulA = A, *MulB = B;
  if (WidthA < MulWidth)
    MulA = Builder.CreateZExt(A, MulType);
  if (WidthB < MulWidth)
    MulB = Builder.CreateZExt(B, MulType);
  Function *F = Intrinsic::getDeclaration(
      I.getModule(), Intrinsic::umul_with_overflow, MulType);
  CallInst *Call = Builder.CreateCall(F, {MulA, MulB}, "umul");
  IC.addToWorklist(MulInstr);

  // If there are uses of mul result other than the comparison, we know that
  // they are truncation or binary AND. Change them to use result of
  // mul.with.overflow and adjust properly mask/size.
  if (MulVal->hasNUsesOrMore(2)) {
    Value *Mul = Builder.CreateExtractValue(Call, 0, "umul.value");
    for (User *U : make_early_inc_range(MulVal->users())) {
      if (U == &I || U == OtherVal)
        continue;
      if (TruncInst *TI = dyn_cast<TruncInst>(U)) {
        // A trunc to exactly MulWidth is the narrow product itself; a narrower
        // trunc is simply re-pointed at the narrow product.
        if (TI->getType()->getPrimitiveSizeInBits() == MulWidth)
          IC.replaceInstUsesWith(*TI, Mul);
        else
          TI->setOperand(0, Mul);
      } else if (BinaryOperator *BO = dyn_cast<BinaryOperator>(U)) {
        assert(BO->getOpcode() == Instruction::And);
        // Replace (mul & mask) --> zext (mul.with.overflow & short_mask)
        ConstantInt *CI = cast<ConstantInt>(BO->getOperand(1));
        APInt ShortMask = CI->getValue().trunc(MulWidth);
        Value *ShortAnd = Builder.CreateAnd(Mul, ShortMask);
        Value *Zext = Builder.CreateZExt(ShortAnd, BO->getType());
        IC.replaceInstUsesWith(*BO, Zext);
      } else {
        llvm_unreachable("Unexpected Binary operation");
      }
      IC.addToWorklist(cast<Instruction>(U));
    }
  }
  if (isa<Instruction>(OtherVal))
    IC.addToWorklist(cast<Instruction>(OtherVal));

  // The original icmp gets replaced with the overflow value, maybe inverted
  // depending on predicate.
  bool Inverse = false;
  switch (I.getPredicate()) {
  case ICmpInst::ICMP_NE:
    break;
  case ICmpInst::ICMP_EQ:
    Inverse = true;
    break;
  case ICmpInst::ICMP_UGT:
  case ICmpInst::ICMP_UGE:
    // With the product on the left these are "overflowed" tests; with it on
    // the right the sense flips.
    if (I.getOperand(0) == MulVal)
      break;
    Inverse = true;
    break;
  case ICmpInst::ICMP_ULT:
  case ICmpInst::ICMP_ULE:
    // Mirror image of the UGT/UGE case.
    if (I.getOperand(1) == MulVal)
      break;
    Inverse = true;
    break;
  default:
    llvm_unreachable("Unexpected predicate");
  }
  if (Inverse) {
    Value *Res = Builder.CreateExtractValue(Call, 1);
    return BinaryOperator::CreateNot(Res);
  }

  return ExtractValueInst::Create(Call, 1);
}

/// When performing a comparison against a constant, it is possible that not all
/// the bits in the LHS are demanded. This helper method computes the mask that
/// IS demanded.
///
/// \param I        Compare instruction; operand 1 is inspected for a constant.
/// \param BitWidth Width of the returned mask.
/// \returns All ones when operand 1 is not matched by m_APInt or the predicate
///          is not handled specially; otherwise the reduced mask.
static APInt getDemandedBitsLHSMask(ICmpInst &I, unsigned BitWidth) {
  const APInt *RHS;
  if (!match(I.getOperand(1), m_APInt(RHS)))
    return APInt::getAllOnes(BitWidth);

  // If this is a normal comparison, it demands all bits. If it is a sign bit
  // comparison, it only demands the sign bit.
  bool UnusedBit;
  if (InstCombiner::isSignBitCheck(I.getPredicate(), *RHS, UnusedBit))
    return APInt::getSignMask(BitWidth);

  switch (I.getPredicate()) {
  // For a UGT comparison, we don't care about any bits that
  // correspond to the trailing ones of the comparand. The value of these
  // bits doesn't impact the outcome of the comparison, because any value
  // greater than the RHS must differ in a bit higher than these due to carry.
  case ICmpInst::ICMP_UGT:
    return APInt::getBitsSetFrom(BitWidth, RHS->countTrailingOnes());

  // Similarly, for a ULT comparison, we don't care about the trailing zeros.
  // Any value less than the RHS must differ in a higher bit because of carries.
  case ICmpInst::ICMP_ULT:
    return APInt::getBitsSetFrom(BitWidth, RHS->countTrailingZeros());

  default:
    return APInt::getAllOnes(BitWidth);
  }
}

/// Check if the order of \p Op0 and \p Op1 as operands in an ICmpInst
/// should be swapped.
/// The decision is based on how many times these two operands are reused
/// as subtract operands and their positions in those instructions.
/// The rationale is that several architectures use the same instruction for
/// both subtract and cmp. Thus, it is better if the order of those operands
/// match.
/// \return true if Op0 and Op1 should be swapped.
static bool swapMayExposeCSEOpportunities(const Value *Op0, const Value *Op1) {
  // Filter out pointer values as those cannot appear directly in subtract.
  // FIXME: we may want to go through inttoptrs or bitcasts.
  if (Op0->getType()->isPointerTy())
    return false;
  // If a subtract already has the same operands as a compare, swapping would be
  // bad. If a subtract has the same operands as a compare but in reverse order,
  // then swapping is good.
  // Only users of Op0 are scanned; a matching sub uses both operands anyway.
  int GoodToSwap = 0;
  for (const User *U : Op0->users()) {
    if (match(U, m_Sub(m_Specific(Op1), m_Specific(Op0))))
      GoodToSwap++;
    else if (match(U, m_Sub(m_Specific(Op0), m_Specific(Op1))))
      GoodToSwap--;
  }
  return GoodToSwap > 0;
}

/// Check that one use is in the same block as the definition and all
/// other uses are in blocks dominated by a given block.
///
/// \param DI Definition
/// \param UI Use
/// \param DB Block that must dominate all uses of \p DI outside
///           the parent block
/// \return true when \p UI is the only use of \p DI in the parent block
///         and all other uses of \p DI are in blocks dominated by \p DB.
///
bool InstCombinerImpl::dominatesAllUses(const Instruction *DI,
                                        const Instruction *UI,
                                        const BasicBlock *DB) const {
  assert(DI && UI && "Instruction not defined\n");
  // Ignore incomplete definitions.
  if (!DI->getParent())
    return false;
  // DI and UI must be in the same block.
  if (DI->getParent() != UI->getParent())
    return false;
  // Protect from self-referencing blocks.
  if (DI->getParent() == DB)
    return false;
  // Every user other than UI must sit in a block dominated by DB.
  for (const User *U : DI->users()) {
    auto *Usr = cast<Instruction>(U);
    if (Usr != UI && !DT.dominates(DB, Usr->getParent()))
      return false;
  }
  return true;
}

/// Return true when the instruction sequence within a block is select-cmp-br.
///
/// Concretely: the select's block ends in a two-successor branch whose
/// condition is an icmp that has the select as one of its operands.
static bool isChainSelectCmpBranch(const SelectInst *SI) {
  const BasicBlock *BB = SI->getParent();
  if (!BB)
    return false;
  auto *BI = dyn_cast_or_null<BranchInst>(BB->getTerminator());
  if (!BI || BI->getNumSuccessors() != 2)
    return false;
  auto *IC = dyn_cast<ICmpInst>(BI->getCondition());
  if (!IC || (IC->getOperand(0) != SI && IC->getOperand(1) != SI))
    return false;
  return true;
}

/// True when a select result is replaced by one of its operands
/// in select-icmp sequence. This will eventually result in the elimination
/// of the select.
///
/// \param SI    Select instruction
/// \param Icmp  Compare instruction
/// \param SIOpd Operand that replaces the select
///
/// Notes:
/// - The replacement is global and requires dominator information
/// - The caller is responsible for the actual replacement
///
/// Example:
///
/// entry:
///  %4 = select i1 %3, %C* %0, %C* null
///  %5 = icmp eq %C* %4, null
///  br i1 %5, label %9, label %7
///  ...
///  ; <label>:7                                       ; preds = %entry
///  %8 = getelementptr inbounds %C* %4, i64 0, i32 0
///  ...
///
/// can be transformed to
///
///  %5 = icmp eq %C* %0, null
///  %6 = select i1 %3, i1 %5, i1 true
///  br i1 %6, label %9, label %7
///  ...
///  ; <label>:7                                       ; preds = %entry
///  %8 = getelementptr inbounds %C* %0, i64 0, i32 0  // replace by %0!
///
/// Similar when the first operand of the select is a constant or/and
/// the compare is for not equal rather than equal.
///
/// NOTE: The function is only called when the select and compare constants
/// are equal, the optimization can work only for EQ predicates. This is not a
/// major restriction since a NE compare should be 'normalized' to an equal
/// compare, which usually happens in the combiner and test case
/// select-cmp-br.ll checks for it.
bool InstCombinerImpl::replacedSelectWithOperand(SelectInst *SI,
                                                 const ICmpInst *Icmp,
                                                 const unsigned SIOpd) {
  assert((SIOpd == 1 || SIOpd == 2) && "Invalid select operand!");
  if (isChainSelectCmpBranch(SI) && Icmp->getPredicate() == ICmpInst::ICMP_EQ) {
    // Successor 1 of the terminator is the block in which uses get rewritten.
    BasicBlock *Succ = SI->getParent()->getTerminator()->getSuccessor(1);
    // The check for the single predecessor is not the best that can be
    // done. But it protects efficiently against cases like when SI's
    // home block has two successors, Succ and Succ1, and Succ1 is a
    // predecessor of Succ. Then SI can't be replaced by SIOpd because the use
    // that gets replaced can be reached on either path. So the uniqueness
    // check guarantees that the path all uses of SI (outside SI's parent) are
    // on is disjoint from all other paths out of SI. But that information
    // is more expensive to compute, and the trade-off here is in favor
    // of compile-time. It should also be noticed that we check for a single
    // predecessor and not only uniqueness. This is to handle the situation
    // when Succ and Succ1 point to the same basic block.
    if (Succ->getSinglePredecessor() && dominatesAllUses(SI, Icmp, Succ)) {
      NumSel++;
      SI->replaceUsesOutsideBlock(SI->getOperand(SIOpd), SI->getParent());
      return true;
    }
  }
  return false;
}

/// Try to fold the comparison based on range information we can get by checking
/// whether bits are known to be zero or one in the inputs.
+Instruction *InstCombinerImpl::foldICmpUsingKnownBits(ICmpInst &I) { + Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); + Type *Ty = Op0->getType(); + ICmpInst::Predicate Pred = I.getPredicate(); + + // Get scalar or pointer size. + unsigned BitWidth = Ty->isIntOrIntVectorTy() + ? Ty->getScalarSizeInBits() + : DL.getPointerTypeSizeInBits(Ty->getScalarType()); + + if (!BitWidth) + return nullptr; + + KnownBits Op0Known(BitWidth); + KnownBits Op1Known(BitWidth); + + if (SimplifyDemandedBits(&I, 0, + getDemandedBitsLHSMask(I, BitWidth), + Op0Known, 0)) + return &I; + + if (SimplifyDemandedBits(&I, 1, APInt::getAllOnes(BitWidth), Op1Known, 0)) + return &I; + + // Given the known and unknown bits, compute a range that the LHS could be + // in. Compute the Min, Max and RHS values based on the known bits. For the + // EQ and NE we use unsigned values. + APInt Op0Min(BitWidth, 0), Op0Max(BitWidth, 0); + APInt Op1Min(BitWidth, 0), Op1Max(BitWidth, 0); + if (I.isSigned()) { + Op0Min = Op0Known.getSignedMinValue(); + Op0Max = Op0Known.getSignedMaxValue(); + Op1Min = Op1Known.getSignedMinValue(); + Op1Max = Op1Known.getSignedMaxValue(); + } else { + Op0Min = Op0Known.getMinValue(); + Op0Max = Op0Known.getMaxValue(); + Op1Min = Op1Known.getMinValue(); + Op1Max = Op1Known.getMaxValue(); + } + + // If Min and Max are known to be the same, then SimplifyDemandedBits figured + // out that the LHS or RHS is a constant. Constant fold this now, so that + // code below can assume that Min != Max. + if (!isa<Constant>(Op0) && Op0Min == Op0Max) + return new ICmpInst(Pred, ConstantExpr::getIntegerValue(Ty, Op0Min), Op1); + if (!isa<Constant>(Op1) && Op1Min == Op1Max) + return new ICmpInst(Pred, Op0, ConstantExpr::getIntegerValue(Ty, Op1Min)); + + // Don't break up a clamp pattern -- (min(max X, Y), Z) -- by replacing a + // min/max canonical compare with some other compare. That could lead to + // conflict with select canonicalization and infinite looping. 
+ // FIXME: This constraint may go away if min/max intrinsics are canonical. + auto isMinMaxCmp = [&](Instruction &Cmp) { + if (!Cmp.hasOneUse()) + return false; + Value *A, *B; + SelectPatternFlavor SPF = matchSelectPattern(Cmp.user_back(), A, B).Flavor; + if (!SelectPatternResult::isMinOrMax(SPF)) + return false; + return match(Op0, m_MaxOrMin(m_Value(), m_Value())) || + match(Op1, m_MaxOrMin(m_Value(), m_Value())); + }; + if (!isMinMaxCmp(I)) { + switch (Pred) { + default: + break; + case ICmpInst::ICMP_ULT: { + if (Op1Min == Op0Max) // A <u B -> A != B if max(A) == min(B) + return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1); + const APInt *CmpC; + if (match(Op1, m_APInt(CmpC))) { + // A <u C -> A == C-1 if min(A)+1 == C + if (*CmpC == Op0Min + 1) + return new ICmpInst(ICmpInst::ICMP_EQ, Op0, + ConstantInt::get(Op1->getType(), *CmpC - 1)); + // X <u C --> X == 0, if the number of zero bits in the bottom of X + // exceeds the log2 of C. + if (Op0Known.countMinTrailingZeros() >= CmpC->ceilLogBase2()) + return new ICmpInst(ICmpInst::ICMP_EQ, Op0, + Constant::getNullValue(Op1->getType())); + } + break; + } + case ICmpInst::ICMP_UGT: { + if (Op1Max == Op0Min) // A >u B -> A != B if min(A) == max(B) + return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1); + const APInt *CmpC; + if (match(Op1, m_APInt(CmpC))) { + // A >u C -> A == C+1 if max(a)-1 == C + if (*CmpC == Op0Max - 1) + return new ICmpInst(ICmpInst::ICMP_EQ, Op0, + ConstantInt::get(Op1->getType(), *CmpC + 1)); + // X >u C --> X != 0, if the number of zero bits in the bottom of X + // exceeds the log2 of C. 
+ if (Op0Known.countMinTrailingZeros() >= CmpC->getActiveBits()) + return new ICmpInst(ICmpInst::ICMP_NE, Op0, + Constant::getNullValue(Op1->getType())); + } + break; + } + case ICmpInst::ICMP_SLT: { + if (Op1Min == Op0Max) // A <s B -> A != B if max(A) == min(B) + return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1); + const APInt *CmpC; + if (match(Op1, m_APInt(CmpC))) { + if (*CmpC == Op0Min + 1) // A <s C -> A == C-1 if min(A)+1 == C + return new ICmpInst(ICmpInst::ICMP_EQ, Op0, + ConstantInt::get(Op1->getType(), *CmpC - 1)); + } + break; + } + case ICmpInst::ICMP_SGT: { + if (Op1Max == Op0Min) // A >s B -> A != B if min(A) == max(B) + return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1); + const APInt *CmpC; + if (match(Op1, m_APInt(CmpC))) { + if (*CmpC == Op0Max - 1) // A >s C -> A == C+1 if max(A)-1 == C + return new ICmpInst(ICmpInst::ICMP_EQ, Op0, + ConstantInt::get(Op1->getType(), *CmpC + 1)); + } + break; + } + } + } + + // Based on the range information we know about the LHS, see if we can + // simplify this comparison. For example, (x&4) < 8 is always true. + switch (Pred) { + default: + llvm_unreachable("Unknown icmp opcode!"); + case ICmpInst::ICMP_EQ: + case ICmpInst::ICMP_NE: { + if (Op0Max.ult(Op1Min) || Op0Min.ugt(Op1Max)) + return replaceInstUsesWith( + I, ConstantInt::getBool(I.getType(), Pred == CmpInst::ICMP_NE)); + + // If all bits are known zero except for one, then we know at most one bit + // is set. If the comparison is against zero, then this is a check to see if + // *that* bit is set. + APInt Op0KnownZeroInverted = ~Op0Known.Zero; + if (Op1Known.isZero()) { + // If the LHS is an AND with the same constant, look through it. 
+ Value *LHS = nullptr; + const APInt *LHSC; + if (!match(Op0, m_And(m_Value(LHS), m_APInt(LHSC))) || + *LHSC != Op0KnownZeroInverted) + LHS = Op0; + + Value *X; + const APInt *C1; + if (match(LHS, m_Shl(m_Power2(C1), m_Value(X)))) { + Type *XTy = X->getType(); + unsigned Log2C1 = C1->countTrailingZeros(); + APInt C2 = Op0KnownZeroInverted; + APInt C2Pow2 = (C2 & ~(*C1 - 1)) + *C1; + if (C2Pow2.isPowerOf2()) { + // iff (C1 is pow2) & ((C2 & ~(C1-1)) + C1) is pow2): + // ((C1 << X) & C2) == 0 -> X >= (Log2(C2+C1) - Log2(C1)) + // ((C1 << X) & C2) != 0 -> X < (Log2(C2+C1) - Log2(C1)) + unsigned Log2C2 = C2Pow2.countTrailingZeros(); + auto *CmpC = ConstantInt::get(XTy, Log2C2 - Log2C1); + auto NewPred = + Pred == CmpInst::ICMP_EQ ? CmpInst::ICMP_UGE : CmpInst::ICMP_ULT; + return new ICmpInst(NewPred, X, CmpC); + } + } + } + break; + } + case ICmpInst::ICMP_ULT: { + if (Op0Max.ult(Op1Min)) // A <u B -> true if max(A) < min(B) + return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); + if (Op0Min.uge(Op1Max)) // A <u B -> false if min(A) >= max(B) + return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType())); + break; + } + case ICmpInst::ICMP_UGT: { + if (Op0Min.ugt(Op1Max)) // A >u B -> true if min(A) > max(B) + return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); + if (Op0Max.ule(Op1Min)) // A >u B -> false if max(A) <= max(B) + return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType())); + break; + } + case ICmpInst::ICMP_SLT: { + if (Op0Max.slt(Op1Min)) // A <s B -> true if max(A) < min(C) + return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); + if (Op0Min.sge(Op1Max)) // A <s B -> false if min(A) >= max(C) + return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType())); + break; + } + case ICmpInst::ICMP_SGT: { + if (Op0Min.sgt(Op1Max)) // A >s B -> true if min(A) > max(B) + return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); + if (Op0Max.sle(Op1Min)) // A >s B -> false if max(A) <= min(B) + return 
replaceInstUsesWith(I, ConstantInt::getFalse(I.getType())); + break; + } + case ICmpInst::ICMP_SGE: + assert(!isa<ConstantInt>(Op1) && "ICMP_SGE with ConstantInt not folded!"); + if (Op0Min.sge(Op1Max)) // A >=s B -> true if min(A) >= max(B) + return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); + if (Op0Max.slt(Op1Min)) // A >=s B -> false if max(A) < min(B) + return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType())); + if (Op1Min == Op0Max) // A >=s B -> A == B if max(A) == min(B) + return new ICmpInst(ICmpInst::ICMP_EQ, Op0, Op1); + break; + case ICmpInst::ICMP_SLE: + assert(!isa<ConstantInt>(Op1) && "ICMP_SLE with ConstantInt not folded!"); + if (Op0Max.sle(Op1Min)) // A <=s B -> true if max(A) <= min(B) + return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); + if (Op0Min.sgt(Op1Max)) // A <=s B -> false if min(A) > max(B) + return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType())); + if (Op1Max == Op0Min) // A <=s B -> A == B if min(A) == max(B) + return new ICmpInst(ICmpInst::ICMP_EQ, Op0, Op1); + break; + case ICmpInst::ICMP_UGE: + assert(!isa<ConstantInt>(Op1) && "ICMP_UGE with ConstantInt not folded!"); + if (Op0Min.uge(Op1Max)) // A >=u B -> true if min(A) >= max(B) + return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); + if (Op0Max.ult(Op1Min)) // A >=u B -> false if max(A) < min(B) + return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType())); + if (Op1Min == Op0Max) // A >=u B -> A == B if max(A) == min(B) + return new ICmpInst(ICmpInst::ICMP_EQ, Op0, Op1); + break; + case ICmpInst::ICMP_ULE: + assert(!isa<ConstantInt>(Op1) && "ICMP_ULE with ConstantInt not folded!"); + if (Op0Max.ule(Op1Min)) // A <=u B -> true if max(A) <= min(B) + return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); + if (Op0Min.ugt(Op1Max)) // A <=u B -> false if min(A) > max(B) + return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType())); + if (Op1Max == Op0Min) // A <=u B -> A == B if min(A) == max(B) 
+ return new ICmpInst(ICmpInst::ICMP_EQ, Op0, Op1); + break; + } + + // Turn a signed comparison into an unsigned one if both operands are known to + // have the same sign (sign bit known zero in both, or known one in both). + if (I.isSigned() && + ((Op0Known.Zero.isNegative() && Op1Known.Zero.isNegative()) || + (Op0Known.One.isNegative() && Op1Known.One.isNegative()))) + return new ICmpInst(I.getUnsignedPredicate(), Op0, Op1); + + return nullptr; +} + +/// If one operand of an icmp is effectively a bool (a zero- or sign-extended +/// i1, i.e. a value range of {0,1} or {0,-1}), then try to reduce patterns +/// based on that limit. +/// Returns the replacement 'and'/'or' instruction, or nullptr if neither +/// pattern matches. +static Instruction *foldICmpUsingBoolRange(ICmpInst &I, + InstCombiner::BuilderTy &Builder) { + Value *X, *Y; + ICmpInst::Predicate Pred; + + // X must be 0 and bool must be true for "ULT": + // X <u (zext i1 Y) --> (X == 0) & Y + if (match(&I, m_c_ICmp(Pred, m_Value(X), m_OneUse(m_ZExt(m_Value(Y))))) && + Y->getType()->isIntOrIntVectorTy(1) && Pred == ICmpInst::ICMP_ULT) + return BinaryOperator::CreateAnd(Builder.CreateIsNull(X), Y); + + // X must be 0 or bool must be true for "ULE": + // X <=u (sext i1 Y) --> (X == 0) | Y + if (match(&I, m_c_ICmp(Pred, m_Value(X), m_OneUse(m_SExt(m_Value(Y))))) && + Y->getType()->isIntOrIntVectorTy(1) && Pred == ICmpInst::ICMP_ULE) + return BinaryOperator::CreateOr(Builder.CreateIsNull(X), Y); + + return nullptr; +} + +/// Convert a relational integer predicate and its constant operand into the +/// equivalent predicate of the opposite strictness, adjusting the constant by +/// one (e.g. "X u<= C" becomes "X u< C+1"). Returns None if the constant (or +/// any vector element) cannot be adjusted without wrapping, or if it is +/// neither a ConstantInt nor a fixed-width vector of ConstantInt/undef +/// elements. 
+llvm::Optional<std::pair<CmpInst::Predicate, Constant *>> +InstCombiner::getFlippedStrictnessPredicateAndConstant(CmpInst::Predicate Pred, + Constant *C) { + assert(ICmpInst::isRelational(Pred) && ICmpInst::isIntPredicate(Pred) && + "Only for relational integer predicates."); + + Type *Type = C->getType(); + bool IsSigned = ICmpInst::isSigned(Pred); + + CmpInst::Predicate UnsignedPred = ICmpInst::getUnsignedPredicate(Pred); + bool WillIncrement = + UnsignedPred == ICmpInst::ICMP_ULE || UnsignedPred == ICmpInst::ICMP_UGT; + + // Check if the constant operand can be safely incremented/decremented + // without overflowing/underflowing. 
+ auto ConstantIsOk = [WillIncrement, IsSigned](ConstantInt *C) { + return WillIncrement ? !C->isMaxValue(IsSigned) : !C->isMinValue(IsSigned); + }; + + Constant *SafeReplacementConstant = nullptr; + if (auto *CI = dyn_cast<ConstantInt>(C)) { + // Bail out if the constant can't be safely incremented/decremented. + if (!ConstantIsOk(CI)) + return llvm::None; + } else if (auto *FVTy = dyn_cast<FixedVectorType>(Type)) { + unsigned NumElts = FVTy->getNumElements(); + for (unsigned i = 0; i != NumElts; ++i) { + Constant *Elt = C->getAggregateElement(i); + if (!Elt) + return llvm::None; + + // Undef elements are skipped here and patched up after the loop. + if (isa<UndefValue>(Elt)) + continue; + + // Bail out if we can't determine if this constant is min/max or if we + // know that this constant is min/max. + auto *CI = dyn_cast<ConstantInt>(Elt); + if (!CI || !ConstantIsOk(CI)) + return llvm::None; + + if (!SafeReplacementConstant) + SafeReplacementConstant = CI; + } + } else { + // Neither a ConstantInt nor a fixed-width vector constant (e.g. a + // ConstantExpr): give up. + return llvm::None; + } + + // It may not be safe to change a compare predicate in the presence of + // undefined elements, so replace those elements with the first safe constant + // that we found. + // TODO: in case of poison, it is safe; let's replace undefs only. + if (C->containsUndefOrPoisonElement()) { + // NOTE(review): if every element were undef, SafeReplacementConstant would + // still be null here; presumably callers never pass such a constant -- + // confirm. + assert(SafeReplacementConstant && "Replacement constant not set"); + C = Constant::replaceUndefsWith(C, SafeReplacementConstant); + } + + CmpInst::Predicate NewPred = CmpInst::getFlippedStrictnessPredicate(Pred); + + // Increment or decrement the constant. + Constant *OneOrNegOne = ConstantInt::get(Type, WillIncrement ? 
1 : -1, true); + Constant *NewC = ConstantExpr::getAdd(C, OneOrNegOne); + + return std::make_pair(NewPred, NewC); +} + +/// If we have an icmp le or icmp ge instruction with a constant operand, turn +/// it into the appropriate icmp lt or icmp gt instruction. This transform +/// allows them to be folded in visitICmpInst. +/// Returns the new compare, or nullptr if nothing needs to or can be changed. 
+static ICmpInst *canonicalizeCmpWithConstant(ICmpInst &I) { + ICmpInst::Predicate Pred = I.getPredicate(); + if (ICmpInst::isEquality(Pred) || !ICmpInst::isIntPredicate(Pred) || + InstCombiner::isCanonicalPredicate(Pred)) + return nullptr; + + Value *Op0 = I.getOperand(0); + Value *Op1 = I.getOperand(1); + auto *Op1C = dyn_cast<Constant>(Op1); + if (!Op1C) + return nullptr; + + auto FlippedStrictness = + InstCombiner::getFlippedStrictnessPredicateAndConstant(Pred, Op1C); + if (!FlippedStrictness) + return nullptr; + + return new ICmpInst(FlippedStrictness->first, Op0, FlippedStrictness->second); +} + +/// If we have a comparison with a non-canonical predicate and all of its users +/// can be freely inverted, invert the predicate and adjust all the users. +/// Returns &I if it was modified in place, or nullptr if nothing changed. +CmpInst *InstCombinerImpl::canonicalizeICmpPredicate(CmpInst &I) { + // Is the predicate already canonical? + CmpInst::Predicate Pred = I.getPredicate(); + if (InstCombiner::isCanonicalPredicate(Pred)) + return nullptr; + + // Can all users be adjusted to the inverted predicate? + if (!InstCombiner::canFreelyInvertAllUsersOf(&I, /*IgnoredUser=*/nullptr)) + return nullptr; + + // Ok, we can canonicalize comparison! + // Let's first invert the comparison's predicate. + I.setPredicate(CmpInst::getInversePredicate(Pred)); + I.setName(I.getName() + ".not"); + + // And adapt the users to the inverted result. + freelyInvertAllUsersOf(&I); + + return &I; +} + +/// Integer compare with boolean values can always be turned into bitwise ops. +static Instruction *canonicalizeICmpBool(ICmpInst &I, + InstCombiner::BuilderTy &Builder) { + Value *A = I.getOperand(0), *B = I.getOperand(1); + assert(A->getType()->isIntOrIntVectorTy(1) && "Bools only"); + + // A boolean compared to true/false can be simplified to Op0/true/false in + // 14 out of the 20 (10 predicates * 2 constants) possible combinations. 
+ // Cases not handled by InstSimplify are always 'not' of Op0. 
+ if (match(B, m_Zero())) { + switch (I.getPredicate()) { + case CmpInst::ICMP_EQ: // A == 0 -> !A + case CmpInst::ICMP_ULE: // A <=u 0 -> !A + case CmpInst::ICMP_SGE: // A >=s 0 -> !A + return BinaryOperator::CreateNot(A); + default: + llvm_unreachable("ICmp i1 X, C not simplified as expected."); + } + } else if (match(B, m_One())) { + switch (I.getPredicate()) { + case CmpInst::ICMP_NE: // A != 1 -> !A + case CmpInst::ICMP_ULT: // A <u 1 -> !A + case CmpInst::ICMP_SGT: // A >s -1 -> !A (the i1 'one' is -1 when signed) + return BinaryOperator::CreateNot(A); + default: + llvm_unreachable("ICmp i1 X, C not simplified as expected."); + } + } + + switch (I.getPredicate()) { + default: + llvm_unreachable("Invalid icmp instruction!"); + case ICmpInst::ICMP_EQ: + // icmp eq i1 A, B -> ~(A ^ B) + return BinaryOperator::CreateNot(Builder.CreateXor(A, B)); + + case ICmpInst::ICMP_NE: + // icmp ne i1 A, B -> A ^ B + return BinaryOperator::CreateXor(A, B); + + case ICmpInst::ICMP_UGT: + // icmp ugt A, B -> icmp ult B, A (swap operands, then reuse the ult code) + std::swap(A, B); + LLVM_FALLTHROUGH; + case ICmpInst::ICMP_ULT: + // icmp ult i1 A, B -> ~A & B + return BinaryOperator::CreateAnd(Builder.CreateNot(A), B); + + case ICmpInst::ICMP_SGT: + // icmp sgt A, B -> icmp slt B, A (swap operands, then reuse the slt code) + std::swap(A, B); + LLVM_FALLTHROUGH; + case ICmpInst::ICMP_SLT: + // icmp slt i1 A, B -> A & ~B + return BinaryOperator::CreateAnd(Builder.CreateNot(B), A); + + case ICmpInst::ICMP_UGE: + // icmp uge A, B -> icmp ule B, A (swap operands, then reuse the ule code) + std::swap(A, B); + LLVM_FALLTHROUGH; + case ICmpInst::ICMP_ULE: + // icmp ule i1 A, B -> ~A | B + return BinaryOperator::CreateOr(Builder.CreateNot(A), B); + + case ICmpInst::ICMP_SGE: + // icmp sge A, B -> icmp sle B, A (swap operands, then reuse the sle code) + std::swap(A, B); + LLVM_FALLTHROUGH; + case ICmpInst::ICMP_SLE: + // icmp sle i1 A, B -> A | ~B + return 
BinaryOperator::CreateOr(Builder.CreateNot(B), A); + } +} + +// Transform pattern like: +// (1 << Y) u<= X or ~(-1 << Y) u< X or ((1 << Y)+(-1)) u< X +// (1 << Y) u> X or ~(-1 << Y) u>= X or ((1 << Y)+(-1)) u>= X +// Into: +// (X l>> Y) != 0 +// (X l>> Y) == 0 +static 
Instruction *foldICmpWithHighBitMask(ICmpInst &Cmp, + InstCombiner::BuilderTy &Builder) { + ICmpInst::Predicate Pred, NewPred; + Value *X, *Y; + if (match(&Cmp, + m_c_ICmp(Pred, m_OneUse(m_Shl(m_One(), m_Value(Y))), m_Value(X)))) { + switch (Pred) { + case ICmpInst::ICMP_ULE: + NewPred = ICmpInst::ICMP_NE; + break; + case ICmpInst::ICMP_UGT: + NewPred = ICmpInst::ICMP_EQ; + break; + default: + return nullptr; + } + } else if (match(&Cmp, m_c_ICmp(Pred, + m_OneUse(m_CombineOr( + m_Not(m_Shl(m_AllOnes(), m_Value(Y))), + m_Add(m_Shl(m_One(), m_Value(Y)), + m_AllOnes()))), + m_Value(X)))) { + // The variant with 'add' is not canonical (the variant with 'not' is); + // we only get it because it has extra uses and can't be canonicalized. + + switch (Pred) { + case ICmpInst::ICMP_ULT: + NewPred = ICmpInst::ICMP_NE; + break; + case ICmpInst::ICMP_UGE: + NewPred = ICmpInst::ICMP_EQ; + break; + default: + return nullptr; + } + } else + return nullptr; + + // Test whether any bit of X at or above position Y is set. + Value *NewX = Builder.CreateLShr(X, Y, X->getName() + ".highbits"); + Constant *Zero = Constant::getNullValue(NewX->getType()); + return CmpInst::Create(Instruction::ICmp, NewPred, NewX, Zero); +} + +/// Try to move a vector compare ahead of the shuffle(s) feeding it: either both +/// operands are single-source shuffles using the same mask, or operand 0 is a +/// one-use single-source splat shuffle compared against a splat constant. +/// Returns the new shuffle of the compare result, or nullptr. 
+static Instruction *foldVectorCmp(CmpInst &Cmp, + InstCombiner::BuilderTy &Builder) { + const CmpInst::Predicate Pred = Cmp.getPredicate(); + Value *LHS = Cmp.getOperand(0), *RHS = Cmp.getOperand(1); + Value *V1, *V2; + ArrayRef<int> M; + if (!match(LHS, m_Shuffle(m_Value(V1), m_Undef(), m_Mask(M)))) + return nullptr; + + // If both arguments of the cmp are shuffles that use the same mask and + // shuffle within a single vector, move the shuffle after the cmp: + // cmp (shuffle V1, M), (shuffle V2, M) --> shuffle (cmp V1, V2), M + Type *V1Ty = V1->getType(); + if (match(RHS, m_Shuffle(m_Value(V2), m_Undef(), m_SpecificMask(M))) && + V1Ty == V2->getType() && (LHS->hasOneUse() || RHS->hasOneUse())) { + Value *NewCmp = Builder.CreateCmp(Pred, V1, V2); + return new ShuffleVectorInst(NewCmp, M); + } + + // Try to canonicalize 
compare with splatted operand and splat constant. + // TODO: We could generalize this for more than splats. See/use the code in + // InstCombiner::foldVectorBinop(). + Constant *C; + if (!LHS->hasOneUse() || !match(RHS, m_Constant(C))) + return nullptr; + + // Length-changing splats are ok, so adjust the constants as needed: + // cmp (shuffle V1, M), C --> shuffle (cmp V1, C'), M + // where C' is the splat value of C re-splatted to V1's element count. + Constant *ScalarC = C->getSplatValue(/* AllowUndefs */ true); + int MaskSplatIndex; + if (ScalarC && match(M, m_SplatOrUndefMask(MaskSplatIndex))) { + // We allow undefs in matching, but this transform removes those for safety. + // Demanded elements analysis should be able to recover some/all of that. + C = ConstantVector::getSplat(cast<VectorType>(V1Ty)->getElementCount(), + ScalarC); + SmallVector<int, 8> NewM(M.size(), MaskSplatIndex); + Value *NewCmp = Builder.CreateCmp(Pred, V1, C); + return new ShuffleVectorInst(NewCmp, NewM); + } + + return nullptr; +} + +// Replace a compare that recomputes the overflow bit of uadd.with.overflow +// with the intrinsic's own overflow result (element 1), e.g.: +// extract(uadd.with.overflow(A, B), 0) ult A +// -> extract(uadd.with.overflow(A, B), 1) +// Returns nullptr if none of the recognized forms matches. 
+static Instruction *foldICmpOfUAddOv(ICmpInst &I) { + CmpInst::Predicate Pred = I.getPredicate(); + Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); + + Value *UAddOv; + Value *A, *B; + auto UAddOvResultPat = m_ExtractValue<0>( + m_Intrinsic<Intrinsic::uadd_with_overflow>(m_Value(A), m_Value(B))); + if (match(Op0, UAddOvResultPat) && + ((Pred == ICmpInst::ICMP_ULT && (Op1 == A || Op1 == B)) || + (Pred == ICmpInst::ICMP_EQ && match(Op1, m_ZeroInt()) && + (match(A, m_One()) || match(B, m_One()))) || + (Pred == ICmpInst::ICMP_NE && match(Op1, m_AllOnes()) && + (match(A, m_AllOnes()) || match(B, m_AllOnes()))))) + // extract(uadd.with.overflow(A, B), 0) <u A + // extract(uadd.with.overflow(A, 1), 0) == 0 + // extract(uadd.with.overflow(A, -1), 0) != -1 + UAddOv = cast<ExtractValueInst>(Op0)->getAggregateOperand(); + else if (match(Op1, UAddOvResultPat) && + Pred == ICmpInst::ICMP_UGT && (Op0 == A || Op0 == B)) + // A > 
extract(uadd.with.overflow(A, B), 0) + UAddOv = cast<ExtractValueInst>(Op1)->getAggregateOperand(); + else + return nullptr; + + return ExtractValueInst::Create(UAddOv, 1); +} + +/// Look through launder/strip.invariant.group when a pointer is compared +/// against null: icmp pred (launder/strip P), null --> icmp pred P, null. +/// Only done when null is not a defined pointer in the pointer's address space. +static Instruction *foldICmpInvariantGroup(ICmpInst &I) { + if (!I.getOperand(0)->getType()->isPointerTy() || + NullPointerIsDefined( + I.getParent()->getParent(), + I.getOperand(0)->getType()->getPointerAddressSpace())) { + return nullptr; + } + Instruction *Op; + if (match(I.getOperand(0), m_Instruction(Op)) && + match(I.getOperand(1), m_Zero()) && + Op->isLaunderOrStripInvariantGroup()) { + return ICmpInst::Create(Instruction::ICmp, I.getPredicate(), + Op->getOperand(0), I.getOperand(1)); + } + return nullptr; +} + +/// This function folds patterns produced by lowering of reduce idioms, such as +/// llvm.vector.reduce.and, which are lowered into instruction chains. This code +/// attempts to generate a smaller number of scalar comparisons instead of +/// vector comparisons when possible. +static Instruction *foldReductionIdiom(ICmpInst &I, + InstCombiner::BuilderTy &Builder, + const DataLayout &DL) { + if (I.getType()->isVectorTy()) + return nullptr; + ICmpInst::Predicate OuterPred, InnerPred; + Value *LHS, *RHS; + + // Match lowering of @llvm.vector.reduce.and. Turn + /// %vec_ne = icmp ne <8 x i8> %lhs, %rhs + /// %scalar_ne = bitcast <8 x i1> %vec_ne to i8 + /// %res = icmp <pred> i8 %scalar_ne, 0 + /// + /// into + /// + /// %lhs.scalar = bitcast <8 x i8> %lhs to i64 + /// %rhs.scalar = bitcast <8 x i8> %rhs to i64 + /// %res = icmp <pred> i64 %lhs.scalar, %rhs.scalar + /// + /// for <pred> in {ne, eq}. 
+ if (!match(&I, m_ICmp(OuterPred, + m_OneUse(m_BitCast(m_OneUse( + m_ICmp(InnerPred, m_Value(LHS), m_Value(RHS))))), + m_Zero()))) + return nullptr; + auto *LHSTy = dyn_cast<FixedVectorType>(LHS->getType()); + if (!LHSTy || !LHSTy->getElementType()->isIntegerTy()) + return nullptr; + unsigned NumBits = + LHSTy->getNumElements() * LHSTy->getElementType()->getIntegerBitWidth(); + // TODO: Relax this to "not wider than max legal integer type"? + if (!DL.isLegalInteger(NumBits)) + return nullptr; + + // Only an inner 'icmp ne' under an outer eq/ne compare against zero is + // handled. + if (ICmpInst::isEquality(OuterPred) && InnerPred == ICmpInst::ICMP_NE) { + auto *ScalarTy = Builder.getIntNTy(NumBits); + LHS = Builder.CreateBitCast(LHS, ScalarTy, LHS->getName() + ".scalar"); + RHS = Builder.CreateBitCast(RHS, ScalarTy, RHS->getName() + ".scalar"); + return ICmpInst::Create(Instruction::ICmp, OuterPred, LHS, RHS, + I.getName()); + } + + return nullptr; +} + +/// Visit an integer compare: canonicalize its operand order, try to simplify +/// it, and then try the icmp folds below in sequence. Returns the replacement +/// instruction from the first fold that applies, &I if only the operand order +/// changed, or nullptr. +Instruction *InstCombinerImpl::visitICmpInst(ICmpInst &I) { + bool Changed = false; + const SimplifyQuery Q = SQ.getWithInstruction(&I); + Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); + unsigned Op0Cplxity = getComplexity(Op0); + unsigned Op1Cplxity = getComplexity(Op1); + + /// Order the operands of the compare from most complex to least complex + /// according to getComplexity(), i.e. the operand with the lower complexity + /// ends up as operand 1. 
+ if (Op0Cplxity < Op1Cplxity || + (Op0Cplxity == Op1Cplxity && swapMayExposeCSEOpportunities(Op0, Op1))) { + I.swapOperands(); + std::swap(Op0, Op1); + Changed = true; + } + + if (Value *V = simplifyICmpInst(I.getPredicate(), Op0, Op1, Q)) + return replaceInstUsesWith(I, V); + + // A select between val and -val is non-zero exactly when val is non-zero, + // so the select can be dropped from a compare against zero, + // i.e. abs(val) != 0 -> val != 0 + if (I.getPredicate() == ICmpInst::ICMP_NE && match(Op1, m_Zero())) { + Value *Cond, *SelectTrue, *SelectFalse; + if (match(Op0, m_Select(m_Value(Cond), m_Value(SelectTrue), + m_Value(SelectFalse)))) { + if (Value *V = dyn_castNegVal(SelectTrue)) { + if (V == SelectFalse) + return CmpInst::Create(Instruction::ICmp, I.getPredicate(), V, Op1); + } + else if (Value *V = dyn_castNegVal(SelectFalse)) { + if (V == SelectTrue) + return CmpInst::Create(Instruction::ICmp, I.getPredicate(), V, Op1); + } + } + } + + if (Op0->getType()->isIntOrIntVectorTy(1)) + if (Instruction *Res = canonicalizeICmpBool(I, Builder)) + return Res; + + if (Instruction *Res = canonicalizeCmpWithConstant(I)) + return Res; + + if (Instruction *Res = canonicalizeICmpPredicate(I)) + return Res; + + if (Instruction *Res = foldICmpWithConstant(I)) + return Res; + + if (Instruction *Res = foldICmpWithDominatingICmp(I)) + return Res; + + if (Instruction *Res = foldICmpUsingBoolRange(I, Builder)) + return Res; + + if (Instruction *Res = foldICmpUsingKnownBits(I)) + return Res; + + // Test if the ICmpInst instruction is used exclusively by a select as + // part of a minimum or maximum operation. If so, refrain from doing + // any other folding. This helps out other analyses which understand + // non-obfuscated minimum and maximum idioms, such as ScalarEvolution + // and CodeGen. And in this case, at least one of the comparison + // operands has at least one user besides the compare (the select), + // which would often largely negate the benefit of folding anyway. 
+ // + // Do the same for the other patterns recognized by matchSelectPattern. + if (I.hasOneUse()) + if (SelectInst *SI = dyn_cast<SelectInst>(I.user_back())) { + Value *A, *B; + SelectPatternResult SPR = matchSelectPattern(SI, A, B); + // NOTE(review): this bail-out returns nullptr even if the operands were + // swapped above (Changed == true) -- confirm that is intended. + if (SPR.Flavor != SPF_UNKNOWN) + return nullptr; + } + + // Do this after checking for min/max to prevent infinite looping. + if (Instruction *Res = foldICmpWithZero(I)) + return Res; + + // FIXME: We only do this after checking for min/max to prevent infinite + // looping caused by a reverse canonicalization of these patterns for min/max. + // FIXME: The organization of folds is a mess. These would naturally go into + // canonicalizeCmpWithConstant(), but we can't move all of the above folds + // down here after the min/max restriction. + ICmpInst::Predicate Pred = I.getPredicate(); + const APInt *C; + if (match(Op1, m_APInt(C))) { + // For i32: x >u 2147483647 -> x <s 0 -> true if sign bit set + if (Pred == ICmpInst::ICMP_UGT && C->isMaxSignedValue()) { + Constant *Zero = Constant::getNullValue(Op0->getType()); + return new ICmpInst(ICmpInst::ICMP_SLT, Op0, Zero); + } + + // For i32: x <u 2147483648 -> x >s -1 -> true if sign bit clear + if (Pred == ICmpInst::ICMP_ULT && C->isMinSignedValue()) { + Constant *AllOnes = Constant::getAllOnesValue(Op0->getType()); + return new ICmpInst(ICmpInst::ICMP_SGT, Op0, AllOnes); + } + } + + // The folds in here may rely on wrapping flags and special constants, so + // they can break up min/max idioms in some cases but not seemingly similar + // patterns. + // FIXME: It may be possible to enhance select folding to make this + // unnecessary. It may also be moot if we canonicalize to min/max + // intrinsics. 
+ if (Instruction *Res = foldICmpBinOp(I, Q)) + return Res; + + if (Instruction *Res = foldICmpInstWithConstant(I)) + return Res; + + // Try to match comparison as a sign bit test. Intentionally do this after + // foldICmpInstWithConstant() to potentially let other folds happen first. 
+ if (Instruction *New = foldSignBitTest(I)) + return New; + + if (Instruction *Res = foldICmpInstWithConstantNotInt(I)) + return Res; + + // Try to optimize 'icmp GEP, P' or 'icmp P, GEP'. + if (auto *GEP = dyn_cast<GEPOperator>(Op0)) + if (Instruction *NI = foldGEPICmp(GEP, Op1, I.getPredicate(), I)) + return NI; + if (auto *GEP = dyn_cast<GEPOperator>(Op1)) + if (Instruction *NI = foldGEPICmp(GEP, Op0, I.getSwappedPredicate(), I)) + return NI; + + // Try to optimize 'icmp select, X' or 'icmp X, select'. + if (auto *SI = dyn_cast<SelectInst>(Op0)) + if (Instruction *NI = foldSelectICmp(I.getPredicate(), SI, Op1, I)) + return NI; + if (auto *SI = dyn_cast<SelectInst>(Op1)) + if (Instruction *NI = foldSelectICmp(I.getSwappedPredicate(), SI, Op0, I)) + return NI; + + // Try to optimize equality comparisons against alloca-based pointers. + if (Op0->getType()->isPointerTy() && I.isEquality()) { + assert(Op1->getType()->isPointerTy() && "Comparing pointer with non-pointer?"); + if (auto *Alloca = dyn_cast<AllocaInst>(getUnderlyingObject(Op0))) + if (Instruction *New = foldAllocaCmp(I, Alloca)) + return New; + if (auto *Alloca = dyn_cast<AllocaInst>(getUnderlyingObject(Op1))) + if (Instruction *New = foldAllocaCmp(I, Alloca)) + return New; + } + + if (Instruction *Res = foldICmpBitCast(I)) + return Res; + + // TODO: Hoist this above the min/max bailout. + if (Instruction *R = foldICmpWithCastOp(I)) + return R; + + if (Instruction *Res = foldICmpWithMinMax(I)) + return Res; + + { + Value *A, *B; + // Transform (A & ~B) == 0 --> (A & B) != 0 + // and (A & ~B) != 0 --> (A & B) == 0 + // if A is a power of 2. 
+ if (match(Op0, m_And(m_Value(A), m_Not(m_Value(B)))) && + match(Op1, m_Zero()) && + isKnownToBeAPowerOfTwo(A, false, 0, &I) && I.isEquality()) + return new ICmpInst(I.getInversePredicate(), Builder.CreateAnd(A, B), + Op1); + + // ~X < ~Y --> Y < X + // ~X < C --> X > ~C + // (shown for '<'; the same rewrite is applied to whatever predicate I has) + if (match(Op0, m_Not(m_Value(A)))) { + if (match(Op1, m_Not(m_Value(B)))) + return new ICmpInst(I.getPredicate(), B, A); + + const APInt *C; + if (match(Op1, m_APInt(C))) + return new ICmpInst(I.getSwappedPredicate(), A, + ConstantInt::get(Op1->getType(), ~(*C))); + } + + Instruction *AddI = nullptr; + if (match(&I, m_UAddWithOverflow(m_Value(A), m_Value(B), + m_Instruction(AddI))) && + isa<IntegerType>(A->getType())) { + Value *Result; + Constant *Overflow; + // m_UAddWithOverflow can match patterns that do not include an explicit + // "add" instruction, so check the opcode of the matched op. + if (AddI->getOpcode() == Instruction::Add && + OptimizeOverflowCheck(Instruction::Add, /*Signed*/ false, A, B, *AddI, + Result, Overflow)) { + replaceInstUsesWith(*AddI, Result); + eraseInstFromFunction(*AddI); + return replaceInstUsesWith(I, Overflow); + } + } + + // (zext a) * (zext b) --> llvm.umul.with.overflow. + if (match(Op0, m_Mul(m_ZExt(m_Value(A)), m_ZExt(m_Value(B))))) { + if (Instruction *R = processUMulZExtIdiom(I, Op0, Op1, *this)) + return R; + } + if (match(Op1, m_Mul(m_ZExt(m_Value(A)), m_ZExt(m_Value(B))))) { + if (Instruction *R = processUMulZExtIdiom(I, Op1, Op0, *this)) + return R; + } + } + + if (Instruction *Res = foldICmpEquality(I)) + return Res; + + if (Instruction *Res = foldICmpOfUAddOv(I)) + return Res; + + // The 'cmpxchg' instruction returns an aggregate containing the old value and + // an i1 which indicates whether or not we successfully did the swap. + // + // Replace comparisons between the old value and the expected value with the + // indicator that 'cmpxchg' returns. + // + // N.B. 
This transform is only valid when the 'cmpxchg' is not permitted to + // spuriously fail. With a weak 'cmpxchg', the old value may equal the expected + // value but it is possible for the swap to not occur. + if (I.getPredicate() == ICmpInst::ICMP_EQ) + if (auto *EVI = dyn_cast<ExtractValueInst>(Op0)) + if (auto *ACXI = dyn_cast<AtomicCmpXchgInst>(EVI->getAggregateOperand())) + if (EVI->getIndices()[0] == 0 && ACXI->getCompareOperand() == Op1 && + !ACXI->isWeak()) + return ExtractValueInst::Create(ACXI, 1); + + { + Value *X; + const APInt *C; + // icmp X+Cst, X + if (match(Op0, m_Add(m_Value(X), m_APInt(C))) && Op1 == X) + return foldICmpAddOpConst(X, *C, I.getPredicate()); + + // icmp X, X+Cst + if (match(Op1, m_Add(m_Value(X), m_APInt(C))) && Op0 == X) + return foldICmpAddOpConst(X, *C, I.getSwappedPredicate()); + } + + if (Instruction *Res = foldICmpWithHighBitMask(I, Builder)) + return Res; + + if (I.getType()->isVectorTy()) + if (Instruction *Res = foldVectorCmp(I, Builder)) + return Res; + + if (Instruction *Res = foldICmpInvariantGroup(I)) + return Res; + + if (Instruction *Res = foldReductionIdiom(I, Builder, DL)) + return Res; + + // No fold applied; still report I as changed if its operands were swapped. + return Changed ? &I : nullptr; +} + +/// Fold fcmp ([us]itofp x, cst) if possible. +/// LHSI is the [us]itofp instruction and RHSC the constant it is compared +/// against. Returns nullptr if RHSC is not a ConstantFP or if the conversion +/// could lose information that affects the comparison. +Instruction *InstCombinerImpl::foldFCmpIntToFPConst(FCmpInst &I, + Instruction *LHSI, + Constant *RHSC) { + if (!isa<ConstantFP>(RHSC)) return nullptr; + const APFloat &RHS = cast<ConstantFP>(RHSC)->getValueAPF(); + + // Get the width of the mantissa. We don't want to hack on conversions that + // might lose information from the integer, e.g. "i64 -> float" + int MantissaWidth = LHSI->getType()->getFPMantissaWidth(); + if (MantissaWidth == -1) return nullptr; // Unknown. 
+ + IntegerType *IntTy = cast<IntegerType>(LHSI->getOperand(0)->getType()); + + bool LHSUnsigned = isa<UIToFPInst>(LHSI); + + if (I.isEquality()) { + FCmpInst::Predicate P = I.getPredicate(); + bool IsExact = false; + APSInt RHSCvt(IntTy->getBitWidth(), LHSUnsigned); + RHS.convertToInteger(RHSCvt, APFloat::rmNearestTiesToEven, &IsExact); + + // If the floating point constant isn't an integer value, we know if we will + // ever compare equal / not equal to it. + if (!IsExact) { + // TODO: Can never be -0.0 and other non-representable values + APFloat RHSRoundInt(RHS); + RHSRoundInt.roundToIntegral(APFloat::rmNearestTiesToEven); + if (RHS != RHSRoundInt) { + if (P == FCmpInst::FCMP_OEQ || P == FCmpInst::FCMP_UEQ) + return replaceInstUsesWith(I, Builder.getFalse()); + + assert(P == FCmpInst::FCMP_ONE || P == FCmpInst::FCMP_UNE); + return replaceInstUsesWith(I, Builder.getTrue()); + } + } + + // TODO: If the constant is exactly representable, is it always OK to do + // equality compares as integer? + } + + // Check to see that the input is converted from an integer type that is small + // enough that preserves all bits. TODO: check here for "known" sign bits. + // This would allow us to handle (fptosi (x >>s 62) to float) if x is i64 f.e. + unsigned InputSize = IntTy->getScalarSizeInBits(); + + // Following test does NOT adjust InputSize downwards for signed inputs, + // because the most negative value still requires all the mantissa bits + // to distinguish it from one less than that value. + if ((int)InputSize > MantissaWidth) { + // Conversion would lose accuracy. Check if loss can impact comparison. + int Exp = ilogb(RHS); + if (Exp == APFloat::IEK_Inf) { + int MaxExponent = ilogb(APFloat::getLargest(RHS.getSemantics())); + if (MaxExponent < (int)InputSize - !LHSUnsigned) + // Conversion could create infinity. + return nullptr; + } else { + // Note that if RHS is zero or NaN, then Exp is negative + // and first condition is trivially false. 
+ if (MantissaWidth <= Exp && Exp <= (int)InputSize - !LHSUnsigned) + // Conversion could affect comparison. + return nullptr; + } + } + + // Otherwise, we can potentially simplify the comparison. We know that it + // will always come through as an integer value and we know the constant is + // not a NAN (it would have been previously simplified). + assert(!RHS.isNaN() && "NaN comparison not already folded!"); + + ICmpInst::Predicate Pred; + switch (I.getPredicate()) { + default: llvm_unreachable("Unexpected predicate!"); + case FCmpInst::FCMP_UEQ: + case FCmpInst::FCMP_OEQ: + Pred = ICmpInst::ICMP_EQ; + break; + case FCmpInst::FCMP_UGT: + case FCmpInst::FCMP_OGT: + Pred = LHSUnsigned ? ICmpInst::ICMP_UGT : ICmpInst::ICMP_SGT; + break; + case FCmpInst::FCMP_UGE: + case FCmpInst::FCMP_OGE: + Pred = LHSUnsigned ? ICmpInst::ICMP_UGE : ICmpInst::ICMP_SGE; + break; + case FCmpInst::FCMP_ULT: + case FCmpInst::FCMP_OLT: + Pred = LHSUnsigned ? ICmpInst::ICMP_ULT : ICmpInst::ICMP_SLT; + break; + case FCmpInst::FCMP_ULE: + case FCmpInst::FCMP_OLE: + Pred = LHSUnsigned ? ICmpInst::ICMP_ULE : ICmpInst::ICMP_SLE; + break; + case FCmpInst::FCMP_UNE: + case FCmpInst::FCMP_ONE: + Pred = ICmpInst::ICMP_NE; + break; + case FCmpInst::FCMP_ORD: + return replaceInstUsesWith(I, Builder.getTrue()); + case FCmpInst::FCMP_UNO: + return replaceInstUsesWith(I, Builder.getFalse()); + } + + // Now we know that the APFloat is a normal number, zero or inf. + + // See if the FP constant is too large for the integer. For example, + // comparing an i8 to 300.0. + unsigned IntWidth = IntTy->getScalarSizeInBits(); + + if (!LHSUnsigned) { + // If the RHS value is > SignedMax, fold the comparison. This handles +INF + // and large values. 
+ APFloat SMax(RHS.getSemantics()); + SMax.convertFromAPInt(APInt::getSignedMaxValue(IntWidth), true, + APFloat::rmNearestTiesToEven); + if (SMax < RHS) { // smax < 13123.0 + if (Pred == ICmpInst::ICMP_NE || Pred == ICmpInst::ICMP_SLT || + Pred == ICmpInst::ICMP_SLE) + return replaceInstUsesWith(I, Builder.getTrue()); + return replaceInstUsesWith(I, Builder.getFalse()); + } + } else { + // If the RHS value is > UnsignedMax, fold the comparison. This handles + // +INF and large values. + APFloat UMax(RHS.getSemantics()); + UMax.convertFromAPInt(APInt::getMaxValue(IntWidth), false, + APFloat::rmNearestTiesToEven); + if (UMax < RHS) { // umax < 13123.0 + if (Pred == ICmpInst::ICMP_NE || Pred == ICmpInst::ICMP_ULT || + Pred == ICmpInst::ICMP_ULE) + return replaceInstUsesWith(I, Builder.getTrue()); + return replaceInstUsesWith(I, Builder.getFalse()); + } + } + + if (!LHSUnsigned) { + // See if the RHS value is < SignedMin. + APFloat SMin(RHS.getSemantics()); + SMin.convertFromAPInt(APInt::getSignedMinValue(IntWidth), true, + APFloat::rmNearestTiesToEven); + if (SMin > RHS) { // smin > 12312.0 + if (Pred == ICmpInst::ICMP_NE || Pred == ICmpInst::ICMP_SGT || + Pred == ICmpInst::ICMP_SGE) + return replaceInstUsesWith(I, Builder.getTrue()); + return replaceInstUsesWith(I, Builder.getFalse()); + } + } else { + // See if the RHS value is < UnsignedMin. + APFloat UMin(RHS.getSemantics()); + UMin.convertFromAPInt(APInt::getMinValue(IntWidth), false, + APFloat::rmNearestTiesToEven); + if (UMin > RHS) { // umin > 12312.0 + if (Pred == ICmpInst::ICMP_NE || Pred == ICmpInst::ICMP_UGT || + Pred == ICmpInst::ICMP_UGE) + return replaceInstUsesWith(I, Builder.getTrue()); + return replaceInstUsesWith(I, Builder.getFalse()); + } + } + + // Okay, now we know that the FP constant fits in the range [SMIN, SMAX] or + // [0, UMAX], but it may still be fractional. See if it is fractional by + // casting the FP value to the integer value and back, checking for equality. 
+ // Don't do this for zero, because -0.0 is not fractional. + Constant *RHSInt = LHSUnsigned + ? ConstantExpr::getFPToUI(RHSC, IntTy) + : ConstantExpr::getFPToSI(RHSC, IntTy); + if (!RHS.isZero()) { + bool Equal = LHSUnsigned + ? ConstantExpr::getUIToFP(RHSInt, RHSC->getType()) == RHSC + : ConstantExpr::getSIToFP(RHSInt, RHSC->getType()) == RHSC; + if (!Equal) { + // If we had a comparison against a fractional value, we have to adjust + // the compare predicate and sometimes the value. RHSC is rounded towards + // zero at this point. + switch (Pred) { + default: llvm_unreachable("Unexpected integer comparison!"); + case ICmpInst::ICMP_NE: // (float)int != 4.4 --> true + return replaceInstUsesWith(I, Builder.getTrue()); + case ICmpInst::ICMP_EQ: // (float)int == 4.4 --> false + return replaceInstUsesWith(I, Builder.getFalse()); + case ICmpInst::ICMP_ULE: + // (float)int <= 4.4 --> int <= 4 + // (float)int <= -4.4 --> false + if (RHS.isNegative()) + return replaceInstUsesWith(I, Builder.getFalse()); + break; + case ICmpInst::ICMP_SLE: + // (float)int <= 4.4 --> int <= 4 + // (float)int <= -4.4 --> int < -4 + if (RHS.isNegative()) + Pred = ICmpInst::ICMP_SLT; + break; + case ICmpInst::ICMP_ULT: + // (float)int < -4.4 --> false + // (float)int < 4.4 --> int <= 4 + if (RHS.isNegative()) + return replaceInstUsesWith(I, Builder.getFalse()); + Pred = ICmpInst::ICMP_ULE; + break; + case ICmpInst::ICMP_SLT: + // (float)int < -4.4 --> int < -4 + // (float)int < 4.4 --> int <= 4 + if (!RHS.isNegative()) + Pred = ICmpInst::ICMP_SLE; + break; + case ICmpInst::ICMP_UGT: + // (float)int > 4.4 --> int > 4 + // (float)int > -4.4 --> true + if (RHS.isNegative()) + return replaceInstUsesWith(I, Builder.getTrue()); + break; + case ICmpInst::ICMP_SGT: + // (float)int > 4.4 --> int > 4 + // (float)int > -4.4 --> int >= -4 + if (RHS.isNegative()) + Pred = ICmpInst::ICMP_SGE; + break; + case ICmpInst::ICMP_UGE: + // (float)int >= -4.4 --> true + // (float)int >= 4.4 --> int > 4 + if 
(RHS.isNegative())
+          return replaceInstUsesWith(I, Builder.getTrue());
+        Pred = ICmpInst::ICMP_UGT;
+        break;
+      case ICmpInst::ICMP_SGE:
+        // (float)int >= -4.4   --> int >= -4
+        // (float)int >= 4.4    --> int > 4
+        if (!RHS.isNegative())
+          Pred = ICmpInst::ICMP_SGT;
+        break;
+      }
+    }
+  }
+
+  // Lower this FP comparison into an appropriate integer version of the
+  // comparison.
+  return new ICmpInst(Pred, LHSI->getOperand(0), RHSInt);
+}
+
+/// Fold (C / X) < 0.0 --> X < 0.0 if possible. Swap predicate if necessary.
+///
+/// \param I    The fcmp being visited; supplies the predicate and is checked
+///             for the 'ninf' flag.
+/// \param LHSI The left-hand operand of \p I. visitFCmpInst only calls this
+///             for an fdiv; it is also checked for the 'ninf' flag.
+/// \param RHSC The constant right-hand operand of \p I; must match
+///             m_AnyZeroFP().
+/// \returns A new FCmpInst comparing operand 1 of \p LHSI (X) against
+///          \p RHSC, or nullptr if any of the checks below fails.
+static Instruction *foldFCmpReciprocalAndZero(FCmpInst &I, Instruction *LHSI,
+                                              Constant *RHSC) {
+  // When C is not 0.0 and infinities are not allowed:
+  // (C / X) < 0.0 is a sign-bit test of X
+  // (C / X) < 0.0 --> X < 0.0 (if C is positive)
+  // (C / X) < 0.0 --> X > 0.0 (if C is negative, swap the predicate)
+  //
+  // Proof:
+  // Multiply (C / X) < 0.0 by X * X / C.
+  // - X is non-zero: if it were zero, C / X would be infinite and the flag
+  //   'ninf' would be violated.
+  // - C defines the sign of X * X / C. Thus it also defines whether to swap
+  //   the predicate. C is also non-zero by definition.
+  //
+  // Thus X * X / C is non-zero and the transformation is valid. [qed]
+
+  FCmpInst::Predicate Pred = I.getPredicate();
+
+  // Check that predicates are valid: only the ordered relational predicates
+  // (OGT, OLT, OGE, OLE) are handled.
+  if ((Pred != FCmpInst::FCMP_OGT) && (Pred != FCmpInst::FCMP_OLT) &&
+      (Pred != FCmpInst::FCMP_OGE) && (Pred != FCmpInst::FCMP_OLE))
+    return nullptr;
+
+  // Check that RHS operand is zero (either sign).
+  if (!match(RHSC, m_AnyZeroFP()))
+    return nullptr;
+
+  // Check fastmath flags ('ninf'). Both the division and the compare must
+  // carry the flag.
+  if (!LHSI->hasNoInfs() || !I.hasNoInfs())
+    return nullptr;
+
+  // Check the properties of the dividend. It must be a constant, and it must
+  // not be zero (see Proof).
+  const APFloat *C;
+  if (!match(LHSI->getOperand(0), m_APFloat(C)))
+    return nullptr;
+
+  if (C->isZero())
+    return nullptr;
+
+  // Get swapped predicate if necessary.
+  if (C->isNegative())
+    Pred = I.getSwappedPredicate();
+
+  // NOTE(review): &I is passed as the last FCmpInst constructor argument;
+  // presumably this copies I's fast-math flags onto the new compare --
+  // confirm against the FCmpInst constructor.
+  return new FCmpInst(Pred, LHSI->getOperand(1), RHSC, "", &I);
+}
+
+/// Optimize fabs(X) compared with zero.
+///
+/// Only fires when operand 0 of \p I is fabs(X) and operand 1 is +0.0
+/// (m_PosZeroFP). On success \p I is rewritten in place: its predicate is
+/// updated and operand 0 is replaced by X through \p IC.
+///
+/// \returns The result of IC.replaceOperand() on \p I when a fold applies,
+///          or nullptr when the pattern does not match or the predicate falls
+///          into the default case.
+static Instruction *foldFabsWithFcmpZero(FCmpInst &I, InstCombinerImpl &IC) {
+  Value *X;
+  if (!match(I.getOperand(0), m_FAbs(m_Value(X))) ||
+      !match(I.getOperand(1), m_PosZeroFP()))
+    return nullptr;
+
+  // Helper: set predicate P on the compare and swap fabs(X) for X as its
+  // first operand.
+  auto replacePredAndOp0 = [&IC](FCmpInst *I, FCmpInst::Predicate P, Value *X) {
+    I->setPredicate(P);
+    return IC.replaceOperand(*I, 0, X);
+  };
+
+  switch (I.getPredicate()) {
+  case FCmpInst::FCMP_UGE:
+  case FCmpInst::FCMP_OLT:
+    // fabs(X) >= 0.0 --> true
+    // fabs(X) <  0.0 --> false
+    // These are expected to have been folded before reaching this point.
+    llvm_unreachable("fcmp should have simplified");
+
+  case FCmpInst::FCMP_OGT:
+    // fabs(X) > 0.0 --> X != 0.0
+    return replacePredAndOp0(&I, FCmpInst::FCMP_ONE, X);
+
+  case FCmpInst::FCMP_UGT:
+    // fabs(X) u> 0.0 --> X u!= 0.0
+    return replacePredAndOp0(&I, FCmpInst::FCMP_UNE, X);
+
+  case FCmpInst::FCMP_OLE:
+    // fabs(X) <= 0.0 --> X == 0.0
+    return replacePredAndOp0(&I, FCmpInst::FCMP_OEQ, X);
+
+  case FCmpInst::FCMP_ULE:
+    // fabs(X) u<= 0.0 --> X u== 0.0
+    return replacePredAndOp0(&I, FCmpInst::FCMP_UEQ, X);
+
+  case FCmpInst::FCMP_OGE:
+    // fabs(X) >= 0.0 --> !isnan(X)
+    assert(!I.hasNoNaNs() && "fcmp should have simplified");
+    return replacePredAndOp0(&I, FCmpInst::FCMP_ORD, X);
+
+  case FCmpInst::FCMP_ULT:
+    // fabs(X) u< 0.0 --> isnan(X)
+    assert(!I.hasNoNaNs() && "fcmp should have simplified");
+    return replacePredAndOp0(&I, FCmpInst::FCMP_UNO, X);
+
+  case FCmpInst::FCMP_OEQ:
+  case FCmpInst::FCMP_UEQ:
+  case FCmpInst::FCMP_ONE:
+  case FCmpInst::FCMP_UNE:
+  case FCmpInst::FCMP_ORD:
+  case FCmpInst::FCMP_UNO:
+    // Look through the fabs() because it doesn't change anything but the sign.
+    // fabs(X) == 0.0 --> X == 0.0,
+    // fabs(X) != 0.0 --> X != 0.0
+    // isnan(fabs(X)) --> isnan(X)
+    // !isnan(fabs(X)) --> !isnan(X)
+    return replacePredAndOp0(&I, I.getPredicate(), X);
+
+  default:
+    return nullptr;
+  }
+}
+
+/// Fold a compare of a value against its own negation:
+///   fcmp Pred X, (fneg X)  -->  fcmp Pred X, 0.0
+///   fcmp Pred (fneg X), X  -->  fcmp swap(Pred) X, 0.0
+///
+/// \p I itself is not modified; the operand swap below only affects local
+/// copies.
+///
+/// \returns A new FCmpInst, or nullptr if one operand is not the fneg of the
+///          other.
+static Instruction *foldFCmpFNegCommonOp(FCmpInst &I) {
+  CmpInst::Predicate Pred = I.getPredicate();
+  Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
+
+  // Canonicalize fneg as Op1.
+  if (match(Op0, m_FNeg(m_Value())) && !match(Op1, m_FNeg(m_Value()))) {
+    std::swap(Op0, Op1);
+    Pred = I.getSwappedPredicate();
+  }
+
+  if (!match(Op1, m_FNeg(m_Specific(Op0))))
+    return nullptr;
+
+  // Replace the negated operand with 0.0:
+  // fcmp Pred Op0, -Op0 --> fcmp Pred Op0, 0.0
+  Constant *Zero = ConstantFP::getNullValue(Op0->getType());
+  return new FCmpInst(Pred, Op0, Zero, "", &I);
+}
+
+/// Visit an fcmp instruction and try the folds below in order; the first one
+/// that applies returns.
+///
+/// \returns A replacement instruction (either a newly created instruction or
+///          \p I itself when it was modified in place or had its uses
+///          replaced), or nullptr.
+Instruction *InstCombinerImpl::visitFCmpInst(FCmpInst &I) {
+  bool Changed = false;
+
+  /// Orders the operands of the compare so that they are listed from most
+  /// complex to least complex: the operand with the higher getComplexity()
+  /// value ends up as operand 0.
+  /// NOTE(review): the original comment added "This puts constants before
+  /// unary operators, before binary operators", which reads as the reverse of
+  /// the swap below (later code expects the constant on the RHS) -- confirm
+  /// against getComplexity().
+  if (getComplexity(I.getOperand(0)) < getComplexity(I.getOperand(1))) {
+    I.swapOperands();
+    Changed = true;
+  }
+
+  // Pred/Op0/Op1 are read after the potential swap above.
+  const CmpInst::Predicate Pred = I.getPredicate();
+  Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
+  if (Value *V = simplifyFCmpInst(Pred, Op0, Op1, I.getFastMathFlags(),
+                                  SQ.getWithInstruction(&I)))
+    return replaceInstUsesWith(I, V);
+
+  // Simplify 'fcmp pred X, X'
+  Type *OpType = Op0->getType();
+  assert(OpType == Op1->getType() && "fcmp with different-typed operands?");
+  if (Op0 == Op1) {
+    switch (Pred) {
+      default: break;
+    case FCmpInst::FCMP_UNO:    // True if unordered: isnan(X) | isnan(Y)
+    case FCmpInst::FCMP_ULT:    // True if unordered or less than
+    case FCmpInst::FCMP_UGT:    // True if unordered or greater than
+    case FCmpInst::FCMP_UNE:    // True if unordered or not equal
+      // Canonicalize these to be 'fcmp uno %X, 0.0'.
+      I.setPredicate(FCmpInst::FCMP_UNO);
+      I.setOperand(1, Constant::getNullValue(OpType));
+      return &I;
+
+    case FCmpInst::FCMP_ORD:    // True if ordered (no nans)
+    case FCmpInst::FCMP_OEQ:    // True if ordered and equal
+    case FCmpInst::FCMP_OGE:    // True if ordered and greater than or equal
+    case FCmpInst::FCMP_OLE:    // True if ordered and less than or equal
+      // Canonicalize these to be 'fcmp ord %X, 0.0'.
+      I.setPredicate(FCmpInst::FCMP_ORD);
+      I.setOperand(1, Constant::getNullValue(OpType));
+      return &I;
+    }
+  }
+
+  // If we're just checking for a NaN (ORD/UNO) and have a non-NaN operand,
+  // then canonicalize the operand to 0.0.
+  if (Pred == CmpInst::FCMP_ORD || Pred == CmpInst::FCMP_UNO) {
+    if (!match(Op0, m_PosZeroFP()) && isKnownNeverNaN(Op0, &TLI))
+      return replaceOperand(I, 0, ConstantFP::getNullValue(OpType));
+
+    if (!match(Op1, m_PosZeroFP()) && isKnownNeverNaN(Op1, &TLI))
+      return replaceOperand(I, 1, ConstantFP::getNullValue(OpType));
+  }
+
+  // fcmp pred (fneg X), (fneg Y) -> fcmp swap(pred) X, Y
+  Value *X, *Y;
+  if (match(Op0, m_FNeg(m_Value(X))) && match(Op1, m_FNeg(m_Value(Y))))
+    return new FCmpInst(I.getSwappedPredicate(), X, Y, "", &I);
+
+  if (Instruction *R = foldFCmpFNegCommonOp(I))
+    return R;
+
+  // Test if the FCmpInst instruction is used exclusively by a select as
+  // part of a minimum or maximum operation. If so, refrain from doing
+  // any other folding. This helps out other analyses which understand
+  // non-obfuscated minimum and maximum idioms, such as ScalarEvolution
+  // and CodeGen. And in this case, at least one of the comparison
+  // operands has at least one user besides the compare (the select),
+  // which would often largely negate the benefit of folding anyway.
+  // NOTE(review): this path returns nullptr even if the operand swap at the
+  // top of the function set Changed -- confirm that is intended.
+  if (I.hasOneUse())
+    if (SelectInst *SI = dyn_cast<SelectInst>(I.user_back())) {
+      Value *A, *B;
+      SelectPatternResult SPR = matchSelectPattern(SI, A, B);
+      if (SPR.Flavor != SPF_UNKNOWN)
+        return nullptr;
+    }
+
+  // The sign of 0.0 is ignored by fcmp, so canonicalize to +0.0:
+  // fcmp Pred X, -0.0 --> fcmp Pred X, 0.0
+  if (match(Op1, m_AnyZeroFP()) && !match(Op1, m_PosZeroFP()))
+    return replaceOperand(I, 1, ConstantFP::getNullValue(OpType));
+
+  // Handle fcmp with instruction LHS and constant RHS. Opcodes not listed in
+  // the switch fall through to the remaining folds.
+  Instruction *LHSI;
+  Constant *RHSC;
+  if (match(Op0, m_Instruction(LHSI)) && match(Op1, m_Constant(RHSC))) {
+    switch (LHSI->getOpcode()) {
+    case Instruction::PHI:
+      // Only fold fcmp into the PHI if the phi and fcmp are in the same
+      // block.  If in the same block, we're encouraging jump threading.  If
+      // not, we are just pessimizing the code by making an i1 phi.
+      if (LHSI->getParent() == I.getParent())
+        if (Instruction *NV = foldOpIntoPhi(I, cast<PHINode>(LHSI)))
+          return NV;
+      break;
+    case Instruction::SIToFP:
+    case Instruction::UIToFP:
+      if (Instruction *NV = foldFCmpIntToFPConst(I, LHSI, RHSC))
+        return NV;
+      break;
+    case Instruction::FDiv:
+      if (Instruction *NV = foldFCmpReciprocalAndZero(I, LHSI, RHSC))
+        return NV;
+      break;
+    case Instruction::Load:
+      // Only loads whose address is a GEP directly based on a global
+      // variable are handed to foldCmpLoadFromIndexedGlobal.
+      if (auto *GEP = dyn_cast<GetElementPtrInst>(LHSI->getOperand(0)))
+        if (auto *GV = dyn_cast<GlobalVariable>(GEP->getOperand(0)))
+          if (Instruction *Res = foldCmpLoadFromIndexedGlobal(
+                  cast<LoadInst>(LHSI), GEP, GV, I))
+            return Res;
+      break;
+    }
+  }
+
+  if (Instruction *R = foldFabsWithFcmpZero(I, *this))
+    return R;
+
+  if (match(Op0, m_FNeg(m_Value(X)))) {
+    // fcmp pred (fneg X), C --> fcmp swap(pred) X, -C
+    Constant *C;
+    if (match(Op1, m_Constant(C))) {
+      Constant *NegC = ConstantExpr::getFNeg(C);
+      return new FCmpInst(I.getSwappedPredicate(), X, NegC, "", &I);
+    }
+  }
+
+  if (match(Op0, m_FPExt(m_Value(X)))) {
+    // fcmp (fpext X), (fpext Y) -> fcmp X, Y
+    // Only valid when both extensions start from the same source type.
+    if (match(Op1, m_FPExt(m_Value(Y))) && X->getType() == Y->getType())
+      return new FCmpInst(Pred, X, Y, "", &I);
+
+    const APFloat *C;
+    if (match(Op1, m_APFloat(C))) {
+      // Convert the constant to X's (narrower) semantics and record whether
+      // the conversion lost information.
+      const fltSemantics &FPSem =
+          X->getType()->getScalarType()->getFltSemantics();
+      bool Lossy;
+      APFloat TruncC = *C;
+      TruncC.convert(FPSem, APFloat::rmNearestTiesToEven, &Lossy);
+
+      if (Lossy) {
+        // X can't possibly equal the higher-precision constant, so reduce any
+        // equality comparison.
+        // TODO: Other predicates can be handled via getFCmpCode().
+        switch (Pred) {
+        case FCmpInst::FCMP_OEQ:
+          // X is ordered and equal to an impossible constant --> false
+          return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType()));
+        case FCmpInst::FCMP_ONE:
+          // X is ordered and not equal to an impossible constant --> ordered
+          return new FCmpInst(FCmpInst::FCMP_ORD, X,
+                              ConstantFP::getNullValue(X->getType()));
+        case FCmpInst::FCMP_UEQ:
+          // X is unordered or equal to an impossible constant --> unordered
+          return new FCmpInst(FCmpInst::FCMP_UNO, X,
+                              ConstantFP::getNullValue(X->getType()));
+        case FCmpInst::FCMP_UNE:
+          // X is unordered or not equal to an impossible constant --> true
+          return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType()));
+        default:
+          break;
+        }
+      }
+
+      // fcmp (fpext X), C -> fcmp X, (fptrunc C) if fptrunc is lossless
+      // Avoid lossy conversions and denormals.
+      // Zero is a special case that's OK to convert.
+      APFloat Fabs = TruncC;
+      Fabs.clearSign();
+      if (!Lossy &&
+          (!(Fabs < APFloat::getSmallestNormalized(FPSem)) || Fabs.isZero())) {
+        Constant *NewC = ConstantFP::get(X->getType(), TruncC);
+        return new FCmpInst(Pred, X, NewC, "", &I);
+      }
+    }
+  }
+
+  // Convert a sign-bit test of an FP value into a cast and integer compare.
+  // TODO: Simplify if the copysign constant is 0.0 or NaN.
+  // TODO: Handle non-zero compare constants.
+  // TODO: Handle other predicates.
+  const APFloat *C;
+  if (match(Op0, m_OneUse(m_Intrinsic<Intrinsic::copysign>(m_APFloat(C),
+                                                           m_Value(X)))) &&
+      match(Op1, m_AnyZeroFP()) && !C->isZero() && !C->isNaN()) {
+    // Build an integer type as wide as X's scalar type, vectorized with the
+    // compare's element count when the operands are vectors.
+    Type *IntType = Builder.getIntNTy(X->getType()->getScalarSizeInBits());
+    if (auto *VecTy = dyn_cast<VectorType>(OpType))
+      IntType = VectorType::get(IntType, VecTy->getElementCount());
+
+    // copysign(non-zero constant, X) < 0.0 --> (bitcast X) < 0
+    if (Pred == FCmpInst::FCMP_OLT) {
+      Value *IntX = Builder.CreateBitCast(X, IntType);
+      return new ICmpInst(ICmpInst::ICMP_SLT, IntX,
+                          ConstantInt::getNullValue(IntType));
+    }
+  }
+
+  if (I.getType()->isVectorTy())
+    if (Instruction *Res = foldVectorCmp(I, Builder))
+      return Res;
+
+  // Nothing else applied: report the operand swap (if any) by returning I.
+  return Changed ? &I : nullptr;
+}
diff --git a/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
new file mode 100644
index 000000000000..271154bb3f5a
--- /dev/null
+++ b/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
@@ -0,0 +1,818 @@
+//===- InstCombineInternal.h - InstCombine pass internals -------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+/// \file
+///
+/// This file provides internal interfaces used to implement the InstCombine.
+// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_LIB_TRANSFORMS_INSTCOMBINE_INSTCOMBINEINTERNAL_H +#define LLVM_LIB_TRANSFORMS_INSTCOMBINE_INSTCOMBINEINTERNAL_H + +#include "llvm/ADT/Statistic.h" +#include "llvm/Analysis/InstructionSimplify.h" +#include "llvm/Analysis/TargetFolder.h" +#include "llvm/Analysis/ValueTracking.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/InstVisitor.h" +#include "llvm/IR/PatternMatch.h" +#include "llvm/IR/Value.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/KnownBits.h" +#include "llvm/Transforms/InstCombine/InstCombiner.h" +#include "llvm/Transforms/Utils/Local.h" +#include <cassert> + +#define DEBUG_TYPE "instcombine" +#include "llvm/Transforms/Utils/InstructionWorklist.h" + +using namespace llvm::PatternMatch; + +// As a default, let's assume that we want to be aggressive, +// and attempt to traverse with no limits in attempt to sink negation. +static constexpr unsigned NegatorDefaultMaxDepth = ~0U; + +// Let's guesstimate that most often we will end up visiting/producing +// fairly small number of new instructions. 
+static constexpr unsigned NegatorMaxNodesSSO = 16; + +namespace llvm { + +class AAResults; +class APInt; +class AssumptionCache; +class BlockFrequencyInfo; +class DataLayout; +class DominatorTree; +class GEPOperator; +class GlobalVariable; +class LoopInfo; +class OptimizationRemarkEmitter; +class ProfileSummaryInfo; +class TargetLibraryInfo; +class User; + +class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final + : public InstCombiner, + public InstVisitor<InstCombinerImpl, Instruction *> { +public: + InstCombinerImpl(InstructionWorklist &Worklist, BuilderTy &Builder, + bool MinimizeSize, AAResults *AA, AssumptionCache &AC, + TargetLibraryInfo &TLI, TargetTransformInfo &TTI, + DominatorTree &DT, OptimizationRemarkEmitter &ORE, + BlockFrequencyInfo *BFI, ProfileSummaryInfo *PSI, + const DataLayout &DL, LoopInfo *LI) + : InstCombiner(Worklist, Builder, MinimizeSize, AA, AC, TLI, TTI, DT, ORE, + BFI, PSI, DL, LI) {} + + virtual ~InstCombinerImpl() = default; + + /// Run the combiner over the entire worklist until it is empty. + /// + /// \returns true if the IR is changed. + bool run(); + + // Visitation implementation - Implement instruction combining for different + // instruction types. 
The semantics are as follows: + // Return Value: + // null - No change was made + // I - Change was made, I is still valid, I may be dead though + // otherwise - Change was made, replace I with returned instruction + // + Instruction *visitFNeg(UnaryOperator &I); + Instruction *visitAdd(BinaryOperator &I); + Instruction *visitFAdd(BinaryOperator &I); + Value *OptimizePointerDifference( + Value *LHS, Value *RHS, Type *Ty, bool isNUW); + Instruction *visitSub(BinaryOperator &I); + Instruction *visitFSub(BinaryOperator &I); + Instruction *visitMul(BinaryOperator &I); + Instruction *visitFMul(BinaryOperator &I); + Instruction *visitURem(BinaryOperator &I); + Instruction *visitSRem(BinaryOperator &I); + Instruction *visitFRem(BinaryOperator &I); + bool simplifyDivRemOfSelectWithZeroOp(BinaryOperator &I); + Instruction *commonIRemTransforms(BinaryOperator &I); + Instruction *commonIDivTransforms(BinaryOperator &I); + Instruction *visitUDiv(BinaryOperator &I); + Instruction *visitSDiv(BinaryOperator &I); + Instruction *visitFDiv(BinaryOperator &I); + Value *simplifyRangeCheck(ICmpInst *Cmp0, ICmpInst *Cmp1, bool Inverted); + Instruction *visitAnd(BinaryOperator &I); + Instruction *visitOr(BinaryOperator &I); + bool sinkNotIntoOtherHandOfAndOrOr(BinaryOperator &I); + Instruction *visitXor(BinaryOperator &I); + Instruction *visitShl(BinaryOperator &I); + Value *reassociateShiftAmtsOfTwoSameDirectionShifts( + BinaryOperator *Sh0, const SimplifyQuery &SQ, + bool AnalyzeForSignBitExtraction = false); + Instruction *canonicalizeCondSignextOfHighBitExtractToSignextHighBitExtract( + BinaryOperator &I); + Instruction *foldVariableSignZeroExtensionOfVariableHighBitExtract( + BinaryOperator &OldAShr); + Instruction *visitAShr(BinaryOperator &I); + Instruction *visitLShr(BinaryOperator &I); + Instruction *commonShiftTransforms(BinaryOperator &I); + Instruction *visitFCmpInst(FCmpInst &I); + CmpInst *canonicalizeICmpPredicate(CmpInst &I); + Instruction *visitICmpInst(ICmpInst &I); + 
Instruction *FoldShiftByConstant(Value *Op0, Constant *Op1, + BinaryOperator &I); + Instruction *commonCastTransforms(CastInst &CI); + Instruction *commonPointerCastTransforms(CastInst &CI); + Instruction *visitTrunc(TruncInst &CI); + Instruction *visitZExt(ZExtInst &CI); + Instruction *visitSExt(SExtInst &CI); + Instruction *visitFPTrunc(FPTruncInst &CI); + Instruction *visitFPExt(CastInst &CI); + Instruction *visitFPToUI(FPToUIInst &FI); + Instruction *visitFPToSI(FPToSIInst &FI); + Instruction *visitUIToFP(CastInst &CI); + Instruction *visitSIToFP(CastInst &CI); + Instruction *visitPtrToInt(PtrToIntInst &CI); + Instruction *visitIntToPtr(IntToPtrInst &CI); + Instruction *visitBitCast(BitCastInst &CI); + Instruction *visitAddrSpaceCast(AddrSpaceCastInst &CI); + Instruction *foldItoFPtoI(CastInst &FI); + Instruction *visitSelectInst(SelectInst &SI); + Instruction *visitCallInst(CallInst &CI); + Instruction *visitInvokeInst(InvokeInst &II); + Instruction *visitCallBrInst(CallBrInst &CBI); + + Instruction *SliceUpIllegalIntegerPHI(PHINode &PN); + Instruction *visitPHINode(PHINode &PN); + Instruction *visitGetElementPtrInst(GetElementPtrInst &GEP); + Instruction *visitGEPOfGEP(GetElementPtrInst &GEP, GEPOperator *Src); + Instruction *visitGEPOfBitcast(BitCastInst *BCI, GetElementPtrInst &GEP); + Instruction *visitAllocaInst(AllocaInst &AI); + Instruction *visitAllocSite(Instruction &FI); + Instruction *visitFree(CallInst &FI); + Instruction *visitLoadInst(LoadInst &LI); + Instruction *visitStoreInst(StoreInst &SI); + Instruction *visitAtomicRMWInst(AtomicRMWInst &SI); + Instruction *visitUnconditionalBranchInst(BranchInst &BI); + Instruction *visitBranchInst(BranchInst &BI); + Instruction *visitFenceInst(FenceInst &FI); + Instruction *visitSwitchInst(SwitchInst &SI); + Instruction *visitReturnInst(ReturnInst &RI); + Instruction *visitUnreachableInst(UnreachableInst &I); + Instruction * + foldAggregateConstructionIntoAggregateReuse(InsertValueInst &OrigIVI); + 
Instruction *visitInsertValueInst(InsertValueInst &IV); + Instruction *visitInsertElementInst(InsertElementInst &IE); + Instruction *visitExtractElementInst(ExtractElementInst &EI); + Instruction *visitShuffleVectorInst(ShuffleVectorInst &SVI); + Instruction *visitExtractValueInst(ExtractValueInst &EV); + Instruction *visitLandingPadInst(LandingPadInst &LI); + Instruction *visitVAEndInst(VAEndInst &I); + Value *pushFreezeToPreventPoisonFromPropagating(FreezeInst &FI); + bool freezeOtherUses(FreezeInst &FI); + Instruction *foldFreezeIntoRecurrence(FreezeInst &I, PHINode *PN); + Instruction *visitFreeze(FreezeInst &I); + + /// Specify what to return for unhandled instructions. + Instruction *visitInstruction(Instruction &I) { return nullptr; } + + /// True when DB dominates all uses of DI except UI. + /// UI must be in the same block as DI. + /// The routine checks that the DI parent and DB are different. + bool dominatesAllUses(const Instruction *DI, const Instruction *UI, + const BasicBlock *DB) const; + + /// Try to replace select with select operand SIOpd in SI-ICmp sequence. + bool replacedSelectWithOperand(SelectInst *SI, const ICmpInst *Icmp, + const unsigned SIOpd); + + LoadInst *combineLoadToNewType(LoadInst &LI, Type *NewTy, + const Twine &Suffix = ""); + +private: + bool annotateAnyAllocSite(CallBase &Call, const TargetLibraryInfo *TLI); + bool isDesirableIntType(unsigned BitWidth) const; + bool shouldChangeType(unsigned FromBitWidth, unsigned ToBitWidth) const; + bool shouldChangeType(Type *From, Type *To) const; + Value *dyn_castNegVal(Value *V) const; + + /// Classify whether a cast is worth optimizing. + /// + /// This is a helper to decide whether the simplification of + /// logic(cast(A), cast(B)) to cast(logic(A, B)) should be performed. + /// + /// \param CI The cast we are interested in. + /// + /// \return true if this cast actually results in any code being generated and + /// if it cannot already be eliminated by some other transformation. 
+  bool shouldOptimizeCast(CastInst *CI);
+
+  /// Try to optimize a sequence of instructions checking if an operation
+  /// on LHS and RHS overflows.
+  ///
+  /// If this overflow check is done via one of the overflow check intrinsics,
+  /// then CtxI has to be the call instruction calling that intrinsic.  If this
+  /// overflow check is done by arithmetic followed by a compare, then CtxI has
+  /// to be the arithmetic instruction.
+  ///
+  /// If a simplification is possible, stores the simplified result of the
+  /// operation in OperationResult and result of the overflow check in
+  /// OverflowResult, and return true.  If no simplification is possible,
+  /// returns false.
+  bool OptimizeOverflowCheck(Instruction::BinaryOps BinaryOp, bool IsSigned,
+                             Value *LHS, Value *RHS,
+                             Instruction &CtxI, Value *&OperationResult,
+                             Constant *&OverflowResult);
+
+  Instruction *visitCallBase(CallBase &Call);
+  Instruction *tryOptimizeCall(CallInst *CI);
+  bool transformConstExprCastCall(CallBase &Call);
+  Instruction *transformCallThroughTrampoline(CallBase &Call,
+                                              IntrinsicInst &Tramp);
+
+  Value *simplifyMaskedLoad(IntrinsicInst &II);
+  Instruction *simplifyMaskedStore(IntrinsicInst &II);
+  Instruction *simplifyMaskedGather(IntrinsicInst &II);
+  Instruction *simplifyMaskedScatter(IntrinsicInst &II);
+
+  /// Transform (zext icmp) to bitwise / integer operations in order to
+  /// eliminate it.
+  ///
+  /// \param ICI The icmp of the (zext icmp) pair we are interested in.
+  /// \param CI The zext of the (zext icmp) pair we are interested in.
+  ///
+  /// \return null if the transformation cannot be performed. If the
+  /// transformation can be performed the new instruction that replaces the
+  /// (zext icmp) pair will be returned.
+  Instruction *transformZExtICmp(ICmpInst *ICI, ZExtInst &CI);
+
+  Instruction *transformSExtICmp(ICmpInst *ICI, Instruction &CI);
+
+  // Overflow query helpers. Each willNotOverflow{Signed,Unsigned}{Add,Sub,Mul}
+  // returns true only when the matching computeOverflowFor* query, evaluated
+  // with CxtI as the context instruction, reports
+  // OverflowResult::NeverOverflows.
+
+  /// \returns true if the signed add LHS + RHS is reported as never
+  /// overflowing at \p CxtI.
+  bool willNotOverflowSignedAdd(const Value *LHS, const Value *RHS,
+                                const Instruction &CxtI) const {
+    return computeOverflowForSignedAdd(LHS, RHS, &CxtI) ==
+           OverflowResult::NeverOverflows;
+  }
+
+  /// \returns true if the unsigned add LHS + RHS is reported as never
+  /// overflowing at \p CxtI.
+  bool willNotOverflowUnsignedAdd(const Value *LHS, const Value *RHS,
+                                  const Instruction &CxtI) const {
+    return computeOverflowForUnsignedAdd(LHS, RHS, &CxtI) ==
+           OverflowResult::NeverOverflows;
+  }
+
+  /// Dispatches to the signed or unsigned add query based on \p IsSigned.
+  bool willNotOverflowAdd(const Value *LHS, const Value *RHS,
+                          const Instruction &CxtI, bool IsSigned) const {
+    return IsSigned ? willNotOverflowSignedAdd(LHS, RHS, CxtI)
+                    : willNotOverflowUnsignedAdd(LHS, RHS, CxtI);
+  }
+
+  /// \returns true if the signed sub LHS - RHS is reported as never
+  /// overflowing at \p CxtI.
+  bool willNotOverflowSignedSub(const Value *LHS, const Value *RHS,
+                                const Instruction &CxtI) const {
+    return computeOverflowForSignedSub(LHS, RHS, &CxtI) ==
+           OverflowResult::NeverOverflows;
+  }
+
+  /// \returns true if the unsigned sub LHS - RHS is reported as never
+  /// overflowing at \p CxtI.
+  bool willNotOverflowUnsignedSub(const Value *LHS, const Value *RHS,
+                                  const Instruction &CxtI) const {
+    return computeOverflowForUnsignedSub(LHS, RHS, &CxtI) ==
+           OverflowResult::NeverOverflows;
+  }
+
+  /// Dispatches to the signed or unsigned sub query based on \p IsSigned.
+  bool willNotOverflowSub(const Value *LHS, const Value *RHS,
+                          const Instruction &CxtI, bool IsSigned) const {
+    return IsSigned ? willNotOverflowSignedSub(LHS, RHS, CxtI)
+                    : willNotOverflowUnsignedSub(LHS, RHS, CxtI);
+  }
+
+  /// \returns true if the signed mul LHS * RHS is reported as never
+  /// overflowing at \p CxtI.
+  bool willNotOverflowSignedMul(const Value *LHS, const Value *RHS,
+                                const Instruction &CxtI) const {
+    return computeOverflowForSignedMul(LHS, RHS, &CxtI) ==
+           OverflowResult::NeverOverflows;
+  }
+
+  /// \returns true if the unsigned mul LHS * RHS is reported as never
+  /// overflowing at \p CxtI.
+  bool willNotOverflowUnsignedMul(const Value *LHS, const Value *RHS,
+                                  const Instruction &CxtI) const {
+    return computeOverflowForUnsignedMul(LHS, RHS, &CxtI) ==
+           OverflowResult::NeverOverflows;
+  }
+
+  /// Dispatches to the signed or unsigned mul query based on \p IsSigned.
+  bool willNotOverflowMul(const Value *LHS, const Value *RHS,
+                          const Instruction &CxtI, bool IsSigned) const {
+    return IsSigned ? willNotOverflowSignedMul(LHS, RHS, CxtI)
+                    : willNotOverflowUnsignedMul(LHS, RHS, CxtI);
+  }
+
+  /// Generic entry point: selects the add/sub/mul query from \p Opcode.
+  /// Only Instruction::Add, Instruction::Sub and Instruction::Mul are
+  /// supported; any other opcode hits llvm_unreachable.
+  bool willNotOverflow(BinaryOperator::BinaryOps Opcode, const Value *LHS,
+                       const Value *RHS, const Instruction &CxtI,
+                       bool IsSigned) const {
+    switch (Opcode) {
+    case Instruction::Add: return willNotOverflowAdd(LHS, RHS, CxtI, IsSigned);
+    case Instruction::Sub: return willNotOverflowSub(LHS, RHS, CxtI, IsSigned);
+    case Instruction::Mul: return willNotOverflowMul(LHS, RHS, CxtI, IsSigned);
+    default: llvm_unreachable("Unexpected opcode for overflow query");
+    }
+  }
+
+  Value *EmitGEPOffset(User *GEP);
+  Instruction *scalarizePHI(ExtractElementInst &EI, PHINode *PN);
+  Instruction *foldBitcastExtElt(ExtractElementInst &ExtElt);
+  Instruction *foldCastedBitwiseLogic(BinaryOperator &I);
+  Instruction *foldBinopOfSextBoolToSelect(BinaryOperator &I);
+  Instruction *narrowBinOp(TruncInst &Trunc);
+  Instruction *narrowMaskedBinOp(BinaryOperator &And);
+  Instruction *narrowMathIfNoOverflow(BinaryOperator &I);
+  Instruction *narrowFunnelShift(TruncInst &Trunc);
+  Instruction *optimizeBitCastFromPhi(CastInst &CI, PHINode *PN);
+  Instruction *matchSAddSubSat(IntrinsicInst &MinMax1);
+  Instruction *foldNot(BinaryOperator &I);
+
+  void freelyInvertAllUsersOf(Value *V);
+
+  /// Determine if a pair of casts can be replaced by a single cast.
+  ///
+  /// \param CI1 The first of a pair of casts.
+  /// \param CI2 The second of a pair of casts.
+  ///
+  /// \return 0 if the cast pair cannot be eliminated, otherwise returns an
+  /// Instruction::CastOps value for a cast that can replace the pair, casting
+  /// CI1->getSrcTy() to CI2->getDstTy().
+  ///
+  /// \see CastInst::isEliminableCastPair
+  Instruction::CastOps isEliminableCastPair(const CastInst *CI1,
+                                            const CastInst *CI2);
+  Value *simplifyIntToPtrRoundTripCast(Value *Val);
+
+  Value *foldAndOrOfICmps(ICmpInst *LHS, ICmpInst *RHS, Instruction &I,
+                          bool IsAnd, bool IsLogical = false);
+  Value *foldXorOfICmps(ICmpInst *LHS, ICmpInst *RHS, BinaryOperator &Xor);
+
+  Value *foldEqOfParts(ICmpInst *Cmp0, ICmpInst *Cmp1, bool IsAnd);
+
+  Value *foldAndOrOfICmpsUsingRanges(ICmpInst *ICmp1, ICmpInst *ICmp2,
+                                     bool IsAnd);
+
+  /// Optimize (fcmp)&(fcmp) or (fcmp)|(fcmp).
+  /// NOTE: Unlike most of instcombine, this returns a Value which should
+  /// already be inserted into the function.
+  Value *foldLogicOfFCmps(FCmpInst *LHS, FCmpInst *RHS, bool IsAnd,
+                          bool IsLogicalSelect = false);
+
+  Value *foldAndOrOfICmpsOfAndWithPow2(ICmpInst *LHS, ICmpInst *RHS,
+                                       Instruction *CxtI, bool IsAnd,
+                                       bool IsLogical = false);
+  Value *matchSelectFromAndOr(Value *A, Value *B, Value *C, Value *D);
+  Value *getSelectCondition(Value *A, Value *B);
+
+  Instruction *foldIntrinsicWithOverflowCommon(IntrinsicInst *II);
+  Instruction *foldFPSignBitOps(BinaryOperator &I);
+
+  // Optimize one of these forms:
+  //   and i1 Op, SI / select i1 Op, i1 SI, i1 false (if IsAnd = true)
+  //   or i1 Op, SI  / select i1 Op, i1 true, i1 SI  (if IsAnd = false)
+  // into a simpler select instruction using isImpliedCondition.
+  Instruction *foldAndOrOfSelectUsingImpliedCond(Value *Op, SelectInst &SI,
+                                                 bool IsAnd);
+
+public:
+  /// Inserts an instruction \p New before instruction \p Old
+  ///
+  /// Also adds the new instruction to the worklist and returns \p New so that
+  /// it is suitable for use as the return from the visitation patterns.
+  ///
+  /// \p New must be non-null and must not already belong to a basic block
+  /// (asserted).
+  Instruction *InsertNewInstBefore(Instruction *New, Instruction &Old) {
+    assert(New && !New->getParent() &&
+           "New instruction already inserted into a basic block!");
+    BasicBlock *BB = Old.getParent();
+    BB->getInstList().insert(Old.getIterator(), New); // Insert inst
+    Worklist.add(New);
+    return New;
+  }
+
+  /// Same as InsertNewInstBefore, but also sets the debug loc.
+  ///
+  /// The debug location is copied from \p Old onto \p New before insertion.
+  Instruction *InsertNewInstWith(Instruction *New, Instruction &Old) {
+    New->setDebugLoc(Old.getDebugLoc());
+    return InsertNewInstBefore(New, Old);
+  }
+
+  /// A combiner-aware RAUW-like routine.
+  ///
+  /// This method is to be used when an instruction is found to be dead,
+  /// replaceable with another preexisting expression. Here we add all uses of
+  /// I to the worklist, replace all uses of I with the new value, then return
+  /// I, so that the inst combiner will know that I was modified.
+  ///
+  /// \returns nullptr if \p I has no uses (nothing was changed), otherwise
+  /// &I. Sets MadeIRChange when uses are replaced.
+  Instruction *replaceInstUsesWith(Instruction &I, Value *V) {
+    // If there are no uses to replace, then we return nullptr to indicate that
+    // no changes were made to the program.
+    if (I.use_empty()) return nullptr;
+
+    Worklist.pushUsersToWorkList(I); // Add all modified instrs to worklist.
+
+    // If we are replacing the instruction with itself, this must be in a
+    // segment of unreachable code, so just clobber the instruction.
+    if (&I == V)
+      V = PoisonValue::get(I.getType());
+
+    LLVM_DEBUG(dbgs() << "IC: Replacing " << I << "\n"
+                      << "    with " << *V << '\n');
+
+    I.replaceAllUsesWith(V);
+    MadeIRChange = true;
+    return &I;
+  }
+
+  /// Replace operand of instruction and add old operand to the worklist.
+  ///
+  /// \returns &I, so the result can be returned directly from a visit method.
+  Instruction *replaceOperand(Instruction &I, unsigned OpNum, Value *V) {
+    Worklist.addValue(I.getOperand(OpNum));
+    I.setOperand(OpNum, V);
+    return &I;
+  }
+
+  /// Replace use and add the previously used value to the worklist.
+  void replaceUse(Use &U, Value *NewValue) {
+    Worklist.addValue(U);
+    U = NewValue;
+  }
+
+  /// Create and insert the idiom we use to indicate a block is unreachable
+  /// without having to rewrite the CFG from within InstCombine.
+  ///
+  /// The idiom is a store of the i1 constant true to a poison i1 pointer.
+  /// NOTE(review): InsertAt is passed as the StoreInst insertion point;
+  /// presumably the store is placed before it -- confirm against the
+  /// StoreInst constructor.
+  void CreateNonTerminatorUnreachable(Instruction *InsertAt) {
+    auto &Ctx = InsertAt->getContext();
+    new StoreInst(ConstantInt::getTrue(Ctx),
+                  PoisonValue::get(Type::getInt1PtrTy(Ctx)),
+                  InsertAt);
+  }
+
+
+  /// Combiner aware instruction erasure.
+  ///
+  /// When dealing with an instruction that has side effects or produces a void
+  /// value, we can't rely on DCE to delete the instruction. Instead, visit
+  /// methods should return the value returned by this function.
+  ///
+  /// \p I must have no remaining uses (asserted). Always returns nullptr and
+  /// sets MadeIRChange.
+  Instruction *eraseInstFromFunction(Instruction &I) override {
+    LLVM_DEBUG(dbgs() << "IC: ERASE " << I << '\n');
+    assert(I.use_empty() && "Cannot erase instruction that is used!");
+    salvageDebugInfo(I);
+
+    // Make sure that we reprocess all operands now that we reduced their
+    // use counts.
+    for (Use &Operand : I.operands())
+      if (auto *Inst = dyn_cast<Instruction>(Operand))
+        Worklist.add(Inst);
+
+    Worklist.remove(&I);
+    I.eraseFromParent();
+    MadeIRChange = true;
+    return nullptr; // Don't do anything with I
+  }
+
+  // The wrappers below forward to the llvm:: free functions of the same name,
+  // supplying this combiner's DL, &AC and &DT so callers only pass the
+  // query-specific arguments.
+
+  void computeKnownBits(const Value *V, KnownBits &Known,
+                        unsigned Depth, const Instruction *CxtI) const {
+    llvm::computeKnownBits(V, Known, DL, Depth, &AC, CxtI, &DT);
+  }
+
+  KnownBits computeKnownBits(const Value *V, unsigned Depth,
+                             const Instruction *CxtI) const {
+    return llvm::computeKnownBits(V, DL, Depth, &AC, CxtI, &DT);
+  }
+
+  bool isKnownToBeAPowerOfTwo(const Value *V, bool OrZero = false,
+                              unsigned Depth = 0,
+                              const Instruction *CxtI = nullptr) {
+    return llvm::isKnownToBeAPowerOfTwo(V, DL, OrZero, Depth, &AC, CxtI, &DT);
+  }
+
+  bool MaskedValueIsZero(const Value *V, const APInt &Mask, unsigned Depth = 0,
+                         const Instruction *CxtI = nullptr) const {
+    return llvm::MaskedValueIsZero(V, Mask, DL, Depth, &AC, CxtI, &DT);
+  }
+
+  unsigned ComputeNumSignBits(const Value *Op, unsigned Depth = 0,
+                              const Instruction *CxtI = nullptr) const {
+    return llvm::ComputeNumSignBits(Op, DL, Depth, &AC, CxtI, &DT);
+  }
+
+  OverflowResult computeOverflowForUnsignedMul(const Value *LHS,
+                                               const Value *RHS,
+                                               const Instruction *CxtI) const {
+    return llvm::computeOverflowForUnsignedMul(LHS, RHS, DL, &AC, CxtI, &DT);
+  }
+
+  OverflowResult computeOverflowForSignedMul(const Value *LHS,
+                                             const Value *RHS,
+                                             const Instruction *CxtI) const {
+    return llvm::computeOverflowForSignedMul(LHS, RHS, DL, &AC, CxtI, &DT);
+  }
+
+  OverflowResult computeOverflowForUnsignedAdd(const Value *LHS,
+                                               const Value *RHS,
+                                               const Instruction *CxtI) const {
+    return llvm::computeOverflowForUnsignedAdd(LHS, RHS, DL, &AC, CxtI, &DT);
+  }
+
+  OverflowResult computeOverflowForSignedAdd(const Value *LHS,
+                                             const Value *RHS,
+                                             const Instruction *CxtI) const {
+    return llvm::computeOverflowForSignedAdd(LHS, RHS, DL, &AC, CxtI, &DT);
+  }
+
+  OverflowResult computeOverflowForUnsignedSub(const Value *LHS,
+                                               const Value *RHS,
+                                               const Instruction *CxtI) const {
+    return llvm::computeOverflowForUnsignedSub(LHS, RHS, DL, &AC, CxtI, &DT);
+  }
+
+  OverflowResult computeOverflowForSignedSub(const Value *LHS, const Value *RHS,
+                                             const Instruction *CxtI) const {
+    return llvm::computeOverflowForSignedSub(LHS, RHS, DL, &AC, CxtI, &DT);
+  }
+
+  OverflowResult computeOverflow(
+      Instruction::BinaryOps BinaryOp, bool IsSigned,
+      Value *LHS, Value *RHS, Instruction *CxtI) const;
+
+  /// Performs a few simplifications for operators which are associative
+  /// or commutative.
+  bool SimplifyAssociativeOrCommutative(BinaryOperator &I);
+
+  /// Tries to simplify binary operations which some other binary
+  /// operation distributes over.
+  ///
+  /// It does this either by factorizing out common terms (eg "(A*B)+(A*C)"
+  /// -> "A*(B+C)") or expanding out if this results in simplifications (eg: "A
+  /// & (B | C) -> (A&B) | (A&C)" if this is a win).  Returns the simplified
+  /// value, or null if it didn't simplify.
+  Value *SimplifyUsingDistributiveLaws(BinaryOperator &I);
+
+  /// Tries to simplify add operations using the definition of remainder.
+  ///
+  /// The definition of remainder is X % C = X - (X / C ) * C. The add
+  /// expression X % C0 + (( X / C0 ) % C1) * C0 can be simplified to
+  /// X % (C0 * C1)
+  Value *SimplifyAddWithRemainder(BinaryOperator &I);
+
+  // Binary Op helper for select operations where the expression can be
+  // efficiently reorganized.
+  Value *SimplifySelectsFeedingBinaryOp(BinaryOperator &I, Value *LHS,
+                                        Value *RHS);
+
+  /// This tries to simplify binary operations by factorizing out common terms
+  /// (e. g. "(A*B)+(A*C)" -> "A*(B+C)").
+  Value *tryFactorization(BinaryOperator &, Instruction::BinaryOps, Value *,
+                          Value *, Value *, Value *);
+
+  /// Match a select chain which produces one of three values based on whether
+  /// the LHS is less than, equal to, or greater than RHS respectively.
+ /// Return true if we matched a three way compare idiom. The LHS, RHS, Less, + /// Equal and Greater values are saved in the matching process and returned to + /// the caller. + bool matchThreeWayIntCompare(SelectInst *SI, Value *&LHS, Value *&RHS, + ConstantInt *&Less, ConstantInt *&Equal, + ConstantInt *&Greater); + + /// Attempts to replace V with a simpler value based on the demanded + /// bits. + Value *SimplifyDemandedUseBits(Value *V, APInt DemandedMask, KnownBits &Known, + unsigned Depth, Instruction *CxtI); + bool SimplifyDemandedBits(Instruction *I, unsigned Op, + const APInt &DemandedMask, KnownBits &Known, + unsigned Depth = 0) override; + + /// Helper routine of SimplifyDemandedUseBits. It computes KnownZero/KnownOne + /// bits. It also tries to handle simplifications that can be done based on + /// DemandedMask, but without modifying the Instruction. + Value *SimplifyMultipleUseDemandedBits(Instruction *I, + const APInt &DemandedMask, + KnownBits &Known, + unsigned Depth, Instruction *CxtI); + + /// Helper routine of SimplifyDemandedUseBits. It tries to simplify demanded + /// bit for "r1 = shr x, c1; r2 = shl r1, c2" instruction sequence. + Value *simplifyShrShlDemandedBits( + Instruction *Shr, const APInt &ShrOp1, Instruction *Shl, + const APInt &ShlOp1, const APInt &DemandedMask, KnownBits &Known); + + /// Tries to simplify operands to an integer instruction based on its + /// demanded bits. + bool SimplifyDemandedInstructionBits(Instruction &Inst); + + virtual Value * + SimplifyDemandedVectorElts(Value *V, APInt DemandedElts, APInt &UndefElts, + unsigned Depth = 0, + bool AllowMultipleUsers = false) override; + + /// Canonicalize the position of binops relative to shufflevector. 
+ Instruction *foldVectorBinop(BinaryOperator &Inst); + Instruction *foldVectorSelect(SelectInst &Sel); + Instruction *foldSelectShuffle(ShuffleVectorInst &Shuf); + + /// Given a binary operator, cast instruction, or select which has a PHI node + /// as operand #0, see if we can fold the instruction into the PHI (which is + /// only possible if all operands to the PHI are constants). + Instruction *foldOpIntoPhi(Instruction &I, PHINode *PN); + + /// For a binary operator with 2 phi operands, try to hoist the binary + /// operation before the phi. This can result in fewer instructions in + /// patterns where at least one set of phi operands simplifies. + /// Example: + /// BB3: binop (phi [X, BB1], [C1, BB2]), (phi [Y, BB1], [C2, BB2]) + /// --> + /// BB1: BO = binop X, Y + /// BB3: phi [BO, BB1], [(binop C1, C2), BB2] + Instruction *foldBinopWithPhiOperands(BinaryOperator &BO); + + /// Given an instruction with a select as one operand and a constant as the + /// other operand, try to fold the binary operator into the select arguments. + /// This also works for Cast instructions, which obviously do not have a + /// second operand. + Instruction *FoldOpIntoSelect(Instruction &Op, SelectInst *SI, + bool FoldWithMultiUse = false); + + /// This is a convenience wrapper function for the above two functions. + Instruction *foldBinOpIntoSelectOrPhi(BinaryOperator &I); + + Instruction *foldAddWithConstant(BinaryOperator &Add); + + /// Try to rotate an operation below a PHI node, using PHI nodes for + /// its operands. 
+ Instruction *foldPHIArgOpIntoPHI(PHINode &PN); + Instruction *foldPHIArgBinOpIntoPHI(PHINode &PN); + Instruction *foldPHIArgInsertValueInstructionIntoPHI(PHINode &PN); + Instruction *foldPHIArgExtractValueInstructionIntoPHI(PHINode &PN); + Instruction *foldPHIArgGEPIntoPHI(PHINode &PN); + Instruction *foldPHIArgLoadIntoPHI(PHINode &PN); + Instruction *foldPHIArgZextsIntoPHI(PHINode &PN); + Instruction *foldPHIArgIntToPtrToPHI(PHINode &PN); + + /// If an integer typed PHI has only one use which is an IntToPtr operation, + /// replace the PHI with an existing pointer typed PHI if it exists. Otherwise + /// insert a new pointer typed PHI and replace the original one. + Instruction *foldIntegerTypedPHI(PHINode &PN); + + /// Helper function for FoldPHIArgXIntoPHI() to set debug location for the + /// folded operation. + void PHIArgMergedDebugLoc(Instruction *Inst, PHINode &PN); + + Instruction *foldGEPICmp(GEPOperator *GEPLHS, Value *RHS, + ICmpInst::Predicate Cond, Instruction &I); + Instruction *foldSelectICmp(ICmpInst::Predicate Pred, SelectInst *SI, + Value *RHS, const ICmpInst &I); + Instruction *foldAllocaCmp(ICmpInst &ICI, const AllocaInst *Alloca); + Instruction *foldCmpLoadFromIndexedGlobal(LoadInst *LI, + GetElementPtrInst *GEP, + GlobalVariable *GV, CmpInst &ICI, + ConstantInt *AndCst = nullptr); + Instruction *foldFCmpIntToFPConst(FCmpInst &I, Instruction *LHSI, + Constant *RHSC); + Instruction *foldICmpAddOpConst(Value *X, const APInt &C, + ICmpInst::Predicate Pred); + Instruction *foldICmpWithCastOp(ICmpInst &ICmp); + Instruction *foldICmpWithZextOrSext(ICmpInst &ICmp); + + Instruction *foldICmpUsingKnownBits(ICmpInst &Cmp); + Instruction *foldICmpWithDominatingICmp(ICmpInst &Cmp); + Instruction *foldICmpWithConstant(ICmpInst &Cmp); + Instruction *foldICmpInstWithConstant(ICmpInst &Cmp); + Instruction *foldICmpInstWithConstantNotInt(ICmpInst &Cmp); + Instruction *foldICmpInstWithConstantAllowUndef(ICmpInst &Cmp, + const APInt &C); + Instruction 
*foldICmpBinOp(ICmpInst &Cmp, const SimplifyQuery &SQ); + Instruction *foldICmpEquality(ICmpInst &Cmp); + Instruction *foldIRemByPowerOfTwoToBitTest(ICmpInst &I); + Instruction *foldSignBitTest(ICmpInst &I); + Instruction *foldICmpWithZero(ICmpInst &Cmp); + + Value *foldMultiplicationOverflowCheck(ICmpInst &Cmp); + + Instruction *foldICmpBinOpWithConstant(ICmpInst &Cmp, BinaryOperator *BO, + const APInt &C); + Instruction *foldICmpSelectConstant(ICmpInst &Cmp, SelectInst *Select, + ConstantInt *C); + Instruction *foldICmpTruncConstant(ICmpInst &Cmp, TruncInst *Trunc, + const APInt &C); + Instruction *foldICmpAndConstant(ICmpInst &Cmp, BinaryOperator *And, + const APInt &C); + Instruction *foldICmpXorConstant(ICmpInst &Cmp, BinaryOperator *Xor, + const APInt &C); + Instruction *foldICmpOrConstant(ICmpInst &Cmp, BinaryOperator *Or, + const APInt &C); + Instruction *foldICmpMulConstant(ICmpInst &Cmp, BinaryOperator *Mul, + const APInt &C); + Instruction *foldICmpShlConstant(ICmpInst &Cmp, BinaryOperator *Shl, + const APInt &C); + Instruction *foldICmpShrConstant(ICmpInst &Cmp, BinaryOperator *Shr, + const APInt &C); + Instruction *foldICmpSRemConstant(ICmpInst &Cmp, BinaryOperator *UDiv, + const APInt &C); + Instruction *foldICmpUDivConstant(ICmpInst &Cmp, BinaryOperator *UDiv, + const APInt &C); + Instruction *foldICmpDivConstant(ICmpInst &Cmp, BinaryOperator *Div, + const APInt &C); + Instruction *foldICmpSubConstant(ICmpInst &Cmp, BinaryOperator *Sub, + const APInt &C); + Instruction *foldICmpAddConstant(ICmpInst &Cmp, BinaryOperator *Add, + const APInt &C); + Instruction *foldICmpAndConstConst(ICmpInst &Cmp, BinaryOperator *And, + const APInt &C1); + Instruction *foldICmpAndShift(ICmpInst &Cmp, BinaryOperator *And, + const APInt &C1, const APInt &C2); + Instruction *foldICmpShrConstConst(ICmpInst &I, Value *ShAmt, const APInt &C1, + const APInt &C2); + Instruction *foldICmpShlConstConst(ICmpInst &I, Value *ShAmt, const APInt &C1, + const APInt &C2); + + 
Instruction *foldICmpBinOpEqualityWithConstant(ICmpInst &Cmp, + BinaryOperator *BO, + const APInt &C); + Instruction *foldICmpIntrinsicWithConstant(ICmpInst &ICI, IntrinsicInst *II, + const APInt &C); + Instruction *foldICmpEqIntrinsicWithConstant(ICmpInst &ICI, IntrinsicInst *II, + const APInt &C); + Instruction *foldICmpBitCast(ICmpInst &Cmp); + + // Helpers of visitSelectInst(). + Instruction *foldSelectExtConst(SelectInst &Sel); + Instruction *foldSelectOpOp(SelectInst &SI, Instruction *TI, Instruction *FI); + Instruction *foldSelectIntoOp(SelectInst &SI, Value *, Value *); + Instruction *foldSPFofSPF(Instruction *Inner, SelectPatternFlavor SPF1, + Value *A, Value *B, Instruction &Outer, + SelectPatternFlavor SPF2, Value *C); + Instruction *foldSelectInstWithICmp(SelectInst &SI, ICmpInst *ICI); + Instruction *foldSelectValueEquivalence(SelectInst &SI, ICmpInst &ICI); + + Value *insertRangeTest(Value *V, const APInt &Lo, const APInt &Hi, + bool isSigned, bool Inside); + Instruction *PromoteCastOfAllocation(BitCastInst &CI, AllocaInst &AI); + bool mergeStoreIntoSuccessor(StoreInst &SI); + + /// Given an initial instruction, check to see if it is the root of a + /// bswap/bitreverse idiom. If so, return the equivalent bswap/bitreverse + /// intrinsic. + Instruction *matchBSwapOrBitReverse(Instruction &I, bool MatchBSwaps, + bool MatchBitReversals); + + Instruction *SimplifyAnyMemTransfer(AnyMemTransferInst *MI); + Instruction *SimplifyAnyMemSet(AnyMemSetInst *MI); + + Value *EvaluateInDifferentType(Value *V, Type *Ty, bool isSigned); + + /// Returns a value X such that Val = X * Scale, or null if none. + /// + /// If the multiplication is known not to overflow then NoSignedWrap is set. + Value *Descale(Value *Val, APInt Scale, bool &NoSignedWrap); +}; + +class Negator final { + /// Top-to-bottom, def-to-use negated instruction tree we produced. 
+ SmallVector<Instruction *, NegatorMaxNodesSSO> NewInstructions; + + using BuilderTy = IRBuilder<TargetFolder, IRBuilderCallbackInserter>; + BuilderTy Builder; + + const DataLayout &DL; + AssumptionCache &AC; + const DominatorTree &DT; + + const bool IsTrulyNegation; + + SmallDenseMap<Value *, Value *> NegationsCache; + + Negator(LLVMContext &C, const DataLayout &DL, AssumptionCache &AC, + const DominatorTree &DT, bool IsTrulyNegation); + +#if LLVM_ENABLE_STATS + unsigned NumValuesVisitedInThisNegator = 0; + ~Negator(); +#endif + + using Result = std::pair<ArrayRef<Instruction *> /*NewInstructions*/, + Value * /*NegatedRoot*/>; + + std::array<Value *, 2> getSortedOperandsOfBinOp(Instruction *I); + + LLVM_NODISCARD Value *visitImpl(Value *V, unsigned Depth); + + LLVM_NODISCARD Value *negate(Value *V, unsigned Depth); + + /// Recurse depth-first and attempt to sink the negation. + /// FIXME: use worklist? + LLVM_NODISCARD Optional<Result> run(Value *Root); + + Negator(const Negator &) = delete; + Negator(Negator &&) = delete; + Negator &operator=(const Negator &) = delete; + Negator &operator=(Negator &&) = delete; + +public: + /// Attempt to negate \p Root. Retuns nullptr if negation can't be performed, + /// otherwise returns negated value. 
+ LLVM_NODISCARD static Value *Negate(bool LHSIsZero, Value *Root, + InstCombinerImpl &IC); +}; + +} // end namespace llvm + +#undef DEBUG_TYPE + +#endif // LLVM_LIB_TRANSFORMS_INSTCOMBINE_INSTCOMBINEINTERNAL_H diff --git a/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp b/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp new file mode 100644 index 000000000000..e03b7026f802 --- /dev/null +++ b/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp @@ -0,0 +1,1560 @@ +//===- InstCombineLoadStoreAlloca.cpp -------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the visit functions for load, store and alloca. +// +//===----------------------------------------------------------------------===// + +#include "InstCombineInternal.h" +#include "llvm/ADT/MapVector.h" +#include "llvm/ADT/SmallString.h" +#include "llvm/ADT/Statistic.h" +#include "llvm/Analysis/AliasAnalysis.h" +#include "llvm/Analysis/Loads.h" +#include "llvm/IR/DataLayout.h" +#include "llvm/IR/DebugInfoMetadata.h" +#include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/PatternMatch.h" +#include "llvm/Transforms/InstCombine/InstCombiner.h" +#include "llvm/Transforms/Utils/Local.h" +using namespace llvm; +using namespace PatternMatch; + +#define DEBUG_TYPE "instcombine" + +STATISTIC(NumDeadStore, "Number of dead stores eliminated"); +STATISTIC(NumGlobalCopies, "Number of allocas copied from constant global"); + +/// isOnlyCopiedFromConstantGlobal - Recursively walk the uses of a (derived) +/// pointer to an alloca. 
Ignore any reads of the pointer, return false if we +/// see any stores or other unknown uses. If we see pointer arithmetic, keep +/// track of whether it moves the pointer (with IsOffset) but otherwise traverse +/// the uses. If we see a memcpy/memmove that targets an unoffseted pointer to +/// the alloca, and if the source pointer is a pointer to a constant global, we +/// can optimize this. +static bool +isOnlyCopiedFromConstantMemory(AAResults *AA, + Value *V, MemTransferInst *&TheCopy, + SmallVectorImpl<Instruction *> &ToDelete) { + // We track lifetime intrinsics as we encounter them. If we decide to go + // ahead and replace the value with the global, this lets the caller quickly + // eliminate the markers. + + SmallVector<std::pair<Value *, bool>, 35> ValuesToInspect; + ValuesToInspect.emplace_back(V, false); + while (!ValuesToInspect.empty()) { + auto ValuePair = ValuesToInspect.pop_back_val(); + const bool IsOffset = ValuePair.second; + for (auto &U : ValuePair.first->uses()) { + auto *I = cast<Instruction>(U.getUser()); + + if (auto *LI = dyn_cast<LoadInst>(I)) { + // Ignore non-volatile loads, they are always ok. + if (!LI->isSimple()) return false; + continue; + } + + if (isa<BitCastInst>(I) || isa<AddrSpaceCastInst>(I)) { + // If uses of the bitcast are ok, we are ok. + ValuesToInspect.emplace_back(I, IsOffset); + continue; + } + if (auto *GEP = dyn_cast<GetElementPtrInst>(I)) { + // If the GEP has all zero indices, it doesn't offset the pointer. If it + // doesn't, it does. + ValuesToInspect.emplace_back(I, IsOffset || !GEP->hasAllZeroIndices()); + continue; + } + + if (auto *Call = dyn_cast<CallBase>(I)) { + // If this is the function being called then we treat it like a load and + // ignore it. + if (Call->isCallee(&U)) + continue; + + unsigned DataOpNo = Call->getDataOperandNo(&U); + bool IsArgOperand = Call->isArgOperand(&U); + + // Inalloca arguments are clobbered by the call. 
+ if (IsArgOperand && Call->isInAllocaArgument(DataOpNo)) + return false; + + // If this is a readonly/readnone call site, then we know it is just a + // load (but one that potentially returns the value itself), so we can + // ignore it if we know that the value isn't captured. + if (Call->onlyReadsMemory() && + (Call->use_empty() || Call->doesNotCapture(DataOpNo))) + continue; + + // If this is being passed as a byval argument, the caller is making a + // copy, so it is only a read of the alloca. + if (IsArgOperand && Call->isByValArgument(DataOpNo)) + continue; + } + + // Lifetime intrinsics can be handled by the caller. + if (I->isLifetimeStartOrEnd()) { + assert(I->use_empty() && "Lifetime markers have no result to use!"); + ToDelete.push_back(I); + continue; + } + + // If this is isn't our memcpy/memmove, reject it as something we can't + // handle. + MemTransferInst *MI = dyn_cast<MemTransferInst>(I); + if (!MI) + return false; + + // If the transfer is using the alloca as a source of the transfer, then + // ignore it since it is a load (unless the transfer is volatile). + if (U.getOperandNo() == 1) { + if (MI->isVolatile()) return false; + continue; + } + + // If we already have seen a copy, reject the second one. + if (TheCopy) return false; + + // If the pointer has been offset from the start of the alloca, we can't + // safely handle this. + if (IsOffset) return false; + + // If the memintrinsic isn't using the alloca as the dest, reject it. + if (U.getOperandNo() != 0) return false; + + // If the source of the memcpy/move is not a constant global, reject it. + if (!AA->pointsToConstantMemory(MI->getSource())) + return false; + + // Otherwise, the transform is safe. Remember the copy instruction. + TheCopy = MI; + } + } + return true; +} + +/// isOnlyCopiedFromConstantGlobal - Return true if the specified alloca is only +/// modified by a copy from a constant global. 
If we can prove this, we can +/// replace any uses of the alloca with uses of the global directly. +static MemTransferInst * +isOnlyCopiedFromConstantMemory(AAResults *AA, + AllocaInst *AI, + SmallVectorImpl<Instruction *> &ToDelete) { + MemTransferInst *TheCopy = nullptr; + if (isOnlyCopiedFromConstantMemory(AA, AI, TheCopy, ToDelete)) + return TheCopy; + return nullptr; +} + +/// Returns true if V is dereferenceable for size of alloca. +static bool isDereferenceableForAllocaSize(const Value *V, const AllocaInst *AI, + const DataLayout &DL) { + if (AI->isArrayAllocation()) + return false; + uint64_t AllocaSize = DL.getTypeStoreSize(AI->getAllocatedType()); + if (!AllocaSize) + return false; + return isDereferenceableAndAlignedPointer(V, AI->getAlign(), + APInt(64, AllocaSize), DL); +} + +static Instruction *simplifyAllocaArraySize(InstCombinerImpl &IC, + AllocaInst &AI) { + // Check for array size of 1 (scalar allocation). + if (!AI.isArrayAllocation()) { + // i32 1 is the canonical array size for scalar allocations. + if (AI.getArraySize()->getType()->isIntegerTy(32)) + return nullptr; + + // Canonicalize it. + return IC.replaceOperand(AI, 0, IC.Builder.getInt32(1)); + } + + // Convert: alloca Ty, C - where C is a constant != 1 into: alloca [C x Ty], 1 + if (const ConstantInt *C = dyn_cast<ConstantInt>(AI.getArraySize())) { + if (C->getValue().getActiveBits() <= 64) { + Type *NewTy = ArrayType::get(AI.getAllocatedType(), C->getZExtValue()); + AllocaInst *New = IC.Builder.CreateAlloca(NewTy, AI.getAddressSpace(), + nullptr, AI.getName()); + New->setAlignment(AI.getAlign()); + + // Scan to the end of the allocation instructions, to skip over a block of + // allocas if possible...also skip interleaved debug info + // + BasicBlock::iterator It(New); + while (isa<AllocaInst>(*It) || isa<DbgInfoIntrinsic>(*It)) + ++It; + + // Now that I is pointing to the first non-allocation-inst in the block, + // insert our getelementptr instruction... 
+ // + Type *IdxTy = IC.getDataLayout().getIntPtrType(AI.getType()); + Value *NullIdx = Constant::getNullValue(IdxTy); + Value *Idx[2] = {NullIdx, NullIdx}; + Instruction *GEP = GetElementPtrInst::CreateInBounds( + NewTy, New, Idx, New->getName() + ".sub"); + IC.InsertNewInstBefore(GEP, *It); + + // Now make everything use the getelementptr instead of the original + // allocation. + return IC.replaceInstUsesWith(AI, GEP); + } + } + + if (isa<UndefValue>(AI.getArraySize())) + return IC.replaceInstUsesWith(AI, Constant::getNullValue(AI.getType())); + + // Ensure that the alloca array size argument has type intptr_t, so that + // any casting is exposed early. + Type *IntPtrTy = IC.getDataLayout().getIntPtrType(AI.getType()); + if (AI.getArraySize()->getType() != IntPtrTy) { + Value *V = IC.Builder.CreateIntCast(AI.getArraySize(), IntPtrTy, false); + return IC.replaceOperand(AI, 0, V); + } + + return nullptr; +} + +namespace { +// If I and V are pointers in different address space, it is not allowed to +// use replaceAllUsesWith since I and V have different types. A +// non-target-specific transformation should not use addrspacecast on V since +// the two address space may be disjoint depending on target. +// +// This class chases down uses of the old pointer until reaching the load +// instructions, then replaces the old pointer in the load instructions with +// the new pointer. If during the chasing it sees bitcast or GEP, it will +// create new bitcast or GEP with the new pointer and use them in the load +// instruction. 
+class PointerReplacer { +public: + PointerReplacer(InstCombinerImpl &IC) : IC(IC) {} + + bool collectUsers(Instruction &I); + void replacePointer(Instruction &I, Value *V); + +private: + void replace(Instruction *I); + Value *getReplacement(Value *I); + + SmallSetVector<Instruction *, 4> Worklist; + MapVector<Value *, Value *> WorkMap; + InstCombinerImpl &IC; +}; +} // end anonymous namespace + +bool PointerReplacer::collectUsers(Instruction &I) { + for (auto U : I.users()) { + auto *Inst = cast<Instruction>(&*U); + if (auto *Load = dyn_cast<LoadInst>(Inst)) { + if (Load->isVolatile()) + return false; + Worklist.insert(Load); + } else if (isa<GetElementPtrInst>(Inst) || isa<BitCastInst>(Inst)) { + Worklist.insert(Inst); + if (!collectUsers(*Inst)) + return false; + } else if (auto *MI = dyn_cast<MemTransferInst>(Inst)) { + if (MI->isVolatile()) + return false; + Worklist.insert(Inst); + } else if (Inst->isLifetimeStartOrEnd()) { + continue; + } else { + LLVM_DEBUG(dbgs() << "Cannot handle pointer user: " << *U << '\n'); + return false; + } + } + + return true; +} + +Value *PointerReplacer::getReplacement(Value *V) { return WorkMap.lookup(V); } + +void PointerReplacer::replace(Instruction *I) { + if (getReplacement(I)) + return; + + if (auto *LT = dyn_cast<LoadInst>(I)) { + auto *V = getReplacement(LT->getPointerOperand()); + assert(V && "Operand not replaced"); + auto *NewI = new LoadInst(LT->getType(), V, "", LT->isVolatile(), + LT->getAlign(), LT->getOrdering(), + LT->getSyncScopeID()); + NewI->takeName(LT); + copyMetadataForLoad(*NewI, *LT); + + IC.InsertNewInstWith(NewI, *LT); + IC.replaceInstUsesWith(*LT, NewI); + WorkMap[LT] = NewI; + } else if (auto *GEP = dyn_cast<GetElementPtrInst>(I)) { + auto *V = getReplacement(GEP->getPointerOperand()); + assert(V && "Operand not replaced"); + SmallVector<Value *, 8> Indices; + Indices.append(GEP->idx_begin(), GEP->idx_end()); + auto *NewI = + GetElementPtrInst::Create(GEP->getSourceElementType(), V, Indices); + 
IC.InsertNewInstWith(NewI, *GEP); + NewI->takeName(GEP); + WorkMap[GEP] = NewI; + } else if (auto *BC = dyn_cast<BitCastInst>(I)) { + auto *V = getReplacement(BC->getOperand(0)); + assert(V && "Operand not replaced"); + auto *NewT = PointerType::getWithSamePointeeType( + cast<PointerType>(BC->getType()), + V->getType()->getPointerAddressSpace()); + auto *NewI = new BitCastInst(V, NewT); + IC.InsertNewInstWith(NewI, *BC); + NewI->takeName(BC); + WorkMap[BC] = NewI; + } else if (auto *MemCpy = dyn_cast<MemTransferInst>(I)) { + auto *SrcV = getReplacement(MemCpy->getRawSource()); + // The pointer may appear in the destination of a copy, but we don't want to + // replace it. + if (!SrcV) { + assert(getReplacement(MemCpy->getRawDest()) && + "destination not in replace list"); + return; + } + + IC.Builder.SetInsertPoint(MemCpy); + auto *NewI = IC.Builder.CreateMemTransferInst( + MemCpy->getIntrinsicID(), MemCpy->getRawDest(), MemCpy->getDestAlign(), + SrcV, MemCpy->getSourceAlign(), MemCpy->getLength(), + MemCpy->isVolatile()); + AAMDNodes AAMD = MemCpy->getAAMetadata(); + if (AAMD) + NewI->setAAMetadata(AAMD); + + IC.eraseInstFromFunction(*MemCpy); + WorkMap[MemCpy] = NewI; + } else { + llvm_unreachable("should never reach here"); + } +} + +void PointerReplacer::replacePointer(Instruction &I, Value *V) { +#ifndef NDEBUG + auto *PT = cast<PointerType>(I.getType()); + auto *NT = cast<PointerType>(V->getType()); + assert(PT != NT && PT->hasSameElementTypeAs(NT) && "Invalid usage"); +#endif + WorkMap[&I] = V; + + for (Instruction *Workitem : Worklist) + replace(Workitem); +} + +Instruction *InstCombinerImpl::visitAllocaInst(AllocaInst &AI) { + if (auto *I = simplifyAllocaArraySize(*this, AI)) + return I; + + if (AI.getAllocatedType()->isSized()) { + // Move all alloca's of zero byte objects to the entry block and merge them + // together. Note that we only do this for alloca's, because malloc should + // allocate and return a unique pointer, even for a zero byte allocation. 
+ if (DL.getTypeAllocSize(AI.getAllocatedType()).getKnownMinSize() == 0) { + // For a zero sized alloca there is no point in doing an array allocation. + // This is helpful if the array size is a complicated expression not used + // elsewhere. + if (AI.isArrayAllocation()) + return replaceOperand(AI, 0, + ConstantInt::get(AI.getArraySize()->getType(), 1)); + + // Get the first instruction in the entry block. + BasicBlock &EntryBlock = AI.getParent()->getParent()->getEntryBlock(); + Instruction *FirstInst = EntryBlock.getFirstNonPHIOrDbg(); + if (FirstInst != &AI) { + // If the entry block doesn't start with a zero-size alloca then move + // this one to the start of the entry block. There is no problem with + // dominance as the array size was forced to a constant earlier already. + AllocaInst *EntryAI = dyn_cast<AllocaInst>(FirstInst); + if (!EntryAI || !EntryAI->getAllocatedType()->isSized() || + DL.getTypeAllocSize(EntryAI->getAllocatedType()) + .getKnownMinSize() != 0) { + AI.moveBefore(FirstInst); + return &AI; + } + + // Replace this zero-sized alloca with the one at the start of the entry + // block after ensuring that the address will be aligned enough for both + // types. + const Align MaxAlign = std::max(EntryAI->getAlign(), AI.getAlign()); + EntryAI->setAlignment(MaxAlign); + if (AI.getType() != EntryAI->getType()) + return new BitCastInst(EntryAI, AI.getType()); + return replaceInstUsesWith(AI, EntryAI); + } + } + } + + // Check to see if this allocation is only modified by a memcpy/memmove from + // a constant whose alignment is equal to or exceeds that of the allocation. + // If this is the case, we can change all users to use the constant global + // instead. This is commonly produced by the CFE by constructs like "void + // foo() { int A[] = {1,2,3,4,5,6,7,8,9...}; }" if 'A' is only subsequently + // read. 
+ SmallVector<Instruction *, 4> ToDelete; + if (MemTransferInst *Copy = isOnlyCopiedFromConstantMemory(AA, &AI, ToDelete)) { + Value *TheSrc = Copy->getSource(); + Align AllocaAlign = AI.getAlign(); + Align SourceAlign = getOrEnforceKnownAlignment( + TheSrc, AllocaAlign, DL, &AI, &AC, &DT); + if (AllocaAlign <= SourceAlign && + isDereferenceableForAllocaSize(TheSrc, &AI, DL) && + !isa<Instruction>(TheSrc)) { + // FIXME: Can we sink instructions without violating dominance when TheSrc + // is an instruction instead of a constant or argument? + LLVM_DEBUG(dbgs() << "Found alloca equal to global: " << AI << '\n'); + LLVM_DEBUG(dbgs() << " memcpy = " << *Copy << '\n'); + unsigned SrcAddrSpace = TheSrc->getType()->getPointerAddressSpace(); + auto *DestTy = PointerType::get(AI.getAllocatedType(), SrcAddrSpace); + if (AI.getType()->getAddressSpace() == SrcAddrSpace) { + for (Instruction *Delete : ToDelete) + eraseInstFromFunction(*Delete); + + Value *Cast = Builder.CreateBitCast(TheSrc, DestTy); + Instruction *NewI = replaceInstUsesWith(AI, Cast); + eraseInstFromFunction(*Copy); + ++NumGlobalCopies; + return NewI; + } + + PointerReplacer PtrReplacer(*this); + if (PtrReplacer.collectUsers(AI)) { + for (Instruction *Delete : ToDelete) + eraseInstFromFunction(*Delete); + + Value *Cast = Builder.CreateBitCast(TheSrc, DestTy); + PtrReplacer.replacePointer(AI, Cast); + ++NumGlobalCopies; + } + } + } + + // At last, use the generic allocation site handler to aggressively remove + // unused allocas. + return visitAllocSite(AI); +} + +// Are we allowed to form a atomic load or store of this type? +static bool isSupportedAtomicType(Type *Ty) { + return Ty->isIntOrPtrTy() || Ty->isFloatingPointTy(); +} + +/// Helper to combine a load to a new type. +/// +/// This just does the work of combining a load to a new type. It handles +/// metadata, etc., and returns the new instruction. The \c NewTy should be the +/// loaded *value* type. 
This will convert it to a pointer, cast the operand to +/// that pointer type, load it, etc. +/// +/// Note that this will create all of the instructions with whatever insert +/// point the \c InstCombinerImpl currently is using. +LoadInst *InstCombinerImpl::combineLoadToNewType(LoadInst &LI, Type *NewTy, + const Twine &Suffix) { + assert((!LI.isAtomic() || isSupportedAtomicType(NewTy)) && + "can't fold an atomic load to requested type"); + + Value *Ptr = LI.getPointerOperand(); + unsigned AS = LI.getPointerAddressSpace(); + Type *NewPtrTy = NewTy->getPointerTo(AS); + Value *NewPtr = nullptr; + if (!(match(Ptr, m_BitCast(m_Value(NewPtr))) && + NewPtr->getType() == NewPtrTy)) + NewPtr = Builder.CreateBitCast(Ptr, NewPtrTy); + + LoadInst *NewLoad = Builder.CreateAlignedLoad( + NewTy, NewPtr, LI.getAlign(), LI.isVolatile(), LI.getName() + Suffix); + NewLoad->setAtomic(LI.getOrdering(), LI.getSyncScopeID()); + copyMetadataForLoad(*NewLoad, LI); + return NewLoad; +} + +/// Combine a store to a new type. +/// +/// Returns the newly created store instruction. +static StoreInst *combineStoreToNewValue(InstCombinerImpl &IC, StoreInst &SI, + Value *V) { + assert((!SI.isAtomic() || isSupportedAtomicType(V->getType())) && + "can't fold an atomic store of requested type"); + + Value *Ptr = SI.getPointerOperand(); + unsigned AS = SI.getPointerAddressSpace(); + SmallVector<std::pair<unsigned, MDNode *>, 8> MD; + SI.getAllMetadata(MD); + + StoreInst *NewStore = IC.Builder.CreateAlignedStore( + V, IC.Builder.CreateBitCast(Ptr, V->getType()->getPointerTo(AS)), + SI.getAlign(), SI.isVolatile()); + NewStore->setAtomic(SI.getOrdering(), SI.getSyncScopeID()); + for (const auto &MDPair : MD) { + unsigned ID = MDPair.first; + MDNode *N = MDPair.second; + // Note, essentially every kind of metadata should be preserved here! This + // routine is supposed to clone a store instruction changing *only its + // type*. 
The only metadata it makes sense to drop is metadata which is + // invalidated when the pointer type changes. This should essentially + // never be the case in LLVM, but we explicitly switch over only known + // metadata to be conservatively correct. If you are adding metadata to + // LLVM which pertains to stores, you almost certainly want to add it + // here. + switch (ID) { + case LLVMContext::MD_dbg: + case LLVMContext::MD_tbaa: + case LLVMContext::MD_prof: + case LLVMContext::MD_fpmath: + case LLVMContext::MD_tbaa_struct: + case LLVMContext::MD_alias_scope: + case LLVMContext::MD_noalias: + case LLVMContext::MD_nontemporal: + case LLVMContext::MD_mem_parallel_loop_access: + case LLVMContext::MD_access_group: + // All of these directly apply. + NewStore->setMetadata(ID, N); + break; + case LLVMContext::MD_invariant_load: + case LLVMContext::MD_nonnull: + case LLVMContext::MD_noundef: + case LLVMContext::MD_range: + case LLVMContext::MD_align: + case LLVMContext::MD_dereferenceable: + case LLVMContext::MD_dereferenceable_or_null: + // These don't apply for stores. + break; + } + } + + return NewStore; +} + +/// Returns true if instruction represent minmax pattern like: +/// select ((cmp load V1, load V2), V1, V2). +static bool isMinMaxWithLoads(Value *V, Type *&LoadTy) { + assert(V->getType()->isPointerTy() && "Expected pointer type."); + // Ignore possible ty* to ixx* bitcast. + V = InstCombiner::peekThroughBitcast(V); + // Check that select is select ((cmp load V1, load V2), V1, V2) - minmax + // pattern. 
+ CmpInst::Predicate Pred; + Instruction *L1; + Instruction *L2; + Value *LHS; + Value *RHS; + if (!match(V, m_Select(m_Cmp(Pred, m_Instruction(L1), m_Instruction(L2)), + m_Value(LHS), m_Value(RHS)))) + return false; + LoadTy = L1->getType(); + return (match(L1, m_Load(m_Specific(LHS))) && + match(L2, m_Load(m_Specific(RHS)))) || + (match(L1, m_Load(m_Specific(RHS))) && + match(L2, m_Load(m_Specific(LHS)))); +} + +/// Combine loads to match the type of their uses' value after looking +/// through intervening bitcasts. +/// +/// The core idea here is that if the result of a load is used in an operation, +/// we should load the type most conducive to that operation. For example, when +/// loading an integer and converting that immediately to a pointer, we should +/// instead directly load a pointer. +/// +/// However, this routine must never change the width of a load or the number of +/// loads as that would introduce a semantic change. This combine is expected to +/// be a semantic no-op which just allows loads to more closely model the types +/// of their consuming operations. +/// +/// Currently, we also refuse to change the precise type used for an atomic load +/// or a volatile load. This is debatable, and might be reasonable to change +/// later. However, it is risky in case some backend or other part of LLVM is +/// relying on the exact type loaded to select appropriate atomic operations. +static Instruction *combineLoadToOperationType(InstCombinerImpl &IC, + LoadInst &LI) { + // FIXME: We could probably with some care handle both volatile and ordered + // atomic loads here but it isn't clear that this is important. + if (!LI.isUnordered()) + return nullptr; + + if (LI.use_empty()) + return nullptr; + + // swifterror values can't be bitcasted. + if (LI.getPointerOperand()->isSwiftError()) + return nullptr; + + const DataLayout &DL = IC.getDataLayout(); + + // Fold away bit casts of the loaded value by loading the desired type. 
+ // Note that we should not do this for pointer<->integer casts, + // because that would result in type punning. + if (LI.hasOneUse()) { + // Don't transform when the type is x86_amx, it makes the pass that lower + // x86_amx type happy. + if (auto *BC = dyn_cast<BitCastInst>(LI.user_back())) { + assert(!LI.getType()->isX86_AMXTy() && + "load from x86_amx* should not happen!"); + if (BC->getType()->isX86_AMXTy()) + return nullptr; + } + + if (auto* CI = dyn_cast<CastInst>(LI.user_back())) + if (CI->isNoopCast(DL) && LI.getType()->isPtrOrPtrVectorTy() == + CI->getDestTy()->isPtrOrPtrVectorTy()) + if (!LI.isAtomic() || isSupportedAtomicType(CI->getDestTy())) { + LoadInst *NewLoad = IC.combineLoadToNewType(LI, CI->getDestTy()); + CI->replaceAllUsesWith(NewLoad); + IC.eraseInstFromFunction(*CI); + return &LI; + } + } + + // FIXME: We should also canonicalize loads of vectors when their elements are + // cast to other types. + return nullptr; +} + +static Instruction *unpackLoadToAggregate(InstCombinerImpl &IC, LoadInst &LI) { + // FIXME: We could probably with some care handle both volatile and atomic + // stores here but it isn't clear that this is important. + if (!LI.isSimple()) + return nullptr; + + Type *T = LI.getType(); + if (!T->isAggregateType()) + return nullptr; + + StringRef Name = LI.getName(); + + if (auto *ST = dyn_cast<StructType>(T)) { + // If the struct only have one element, we unpack. + auto NumElements = ST->getNumElements(); + if (NumElements == 1) { + LoadInst *NewLoad = IC.combineLoadToNewType(LI, ST->getTypeAtIndex(0U), + ".unpack"); + NewLoad->setAAMetadata(LI.getAAMetadata()); + return IC.replaceInstUsesWith(LI, IC.Builder.CreateInsertValue( + UndefValue::get(T), NewLoad, 0, Name)); + } + + // We don't want to break loads with padding here as we'd loose + // the knowledge that padding exists for the rest of the pipeline. 
+ const DataLayout &DL = IC.getDataLayout(); + auto *SL = DL.getStructLayout(ST); + if (SL->hasPadding()) + return nullptr; + + const auto Align = LI.getAlign(); + auto *Addr = LI.getPointerOperand(); + auto *IdxType = Type::getInt32Ty(T->getContext()); + auto *Zero = ConstantInt::get(IdxType, 0); + + Value *V = UndefValue::get(T); + for (unsigned i = 0; i < NumElements; i++) { + Value *Indices[2] = { + Zero, + ConstantInt::get(IdxType, i), + }; + auto *Ptr = IC.Builder.CreateInBoundsGEP(ST, Addr, makeArrayRef(Indices), + Name + ".elt"); + auto *L = IC.Builder.CreateAlignedLoad( + ST->getElementType(i), Ptr, + commonAlignment(Align, SL->getElementOffset(i)), Name + ".unpack"); + // Propagate AA metadata. It'll still be valid on the narrowed load. + L->setAAMetadata(LI.getAAMetadata()); + V = IC.Builder.CreateInsertValue(V, L, i); + } + + V->setName(Name); + return IC.replaceInstUsesWith(LI, V); + } + + if (auto *AT = dyn_cast<ArrayType>(T)) { + auto *ET = AT->getElementType(); + auto NumElements = AT->getNumElements(); + if (NumElements == 1) { + LoadInst *NewLoad = IC.combineLoadToNewType(LI, ET, ".unpack"); + NewLoad->setAAMetadata(LI.getAAMetadata()); + return IC.replaceInstUsesWith(LI, IC.Builder.CreateInsertValue( + UndefValue::get(T), NewLoad, 0, Name)); + } + + // Bail out if the array is too large. Ideally we would like to optimize + // arrays of arbitrary size but this has a terrible impact on compile time. + // The threshold here is chosen arbitrarily, maybe needs a little bit of + // tuning. 
+ if (NumElements > IC.MaxArraySizeForCombine) + return nullptr; + + const DataLayout &DL = IC.getDataLayout(); + auto EltSize = DL.getTypeAllocSize(ET); + const auto Align = LI.getAlign(); + + auto *Addr = LI.getPointerOperand(); + auto *IdxType = Type::getInt64Ty(T->getContext()); + auto *Zero = ConstantInt::get(IdxType, 0); + + Value *V = UndefValue::get(T); + uint64_t Offset = 0; + for (uint64_t i = 0; i < NumElements; i++) { + Value *Indices[2] = { + Zero, + ConstantInt::get(IdxType, i), + }; + auto *Ptr = IC.Builder.CreateInBoundsGEP(AT, Addr, makeArrayRef(Indices), + Name + ".elt"); + auto *L = IC.Builder.CreateAlignedLoad(AT->getElementType(), Ptr, + commonAlignment(Align, Offset), + Name + ".unpack"); + L->setAAMetadata(LI.getAAMetadata()); + V = IC.Builder.CreateInsertValue(V, L, i); + Offset += EltSize; + } + + V->setName(Name); + return IC.replaceInstUsesWith(LI, V); + } + + return nullptr; +} + +// If we can determine that all possible objects pointed to by the provided +// pointer value are, not only dereferenceable, but also definitively less than +// or equal to the provided maximum size, then return true. Otherwise, return +// false (constant global values and allocas fall into this category). +// +// FIXME: This should probably live in ValueTracking (or similar). 
+static bool isObjectSizeLessThanOrEq(Value *V, uint64_t MaxSize, + const DataLayout &DL) { + SmallPtrSet<Value *, 4> Visited; + SmallVector<Value *, 4> Worklist(1, V); + + do { + Value *P = Worklist.pop_back_val(); + P = P->stripPointerCasts(); + + if (!Visited.insert(P).second) + continue; + + if (SelectInst *SI = dyn_cast<SelectInst>(P)) { + Worklist.push_back(SI->getTrueValue()); + Worklist.push_back(SI->getFalseValue()); + continue; + } + + if (PHINode *PN = dyn_cast<PHINode>(P)) { + append_range(Worklist, PN->incoming_values()); + continue; + } + + if (GlobalAlias *GA = dyn_cast<GlobalAlias>(P)) { + if (GA->isInterposable()) + return false; + Worklist.push_back(GA->getAliasee()); + continue; + } + + // If we know how big this object is, and it is less than MaxSize, continue + // searching. Otherwise, return false. + if (AllocaInst *AI = dyn_cast<AllocaInst>(P)) { + if (!AI->getAllocatedType()->isSized()) + return false; + + ConstantInt *CS = dyn_cast<ConstantInt>(AI->getArraySize()); + if (!CS) + return false; + + uint64_t TypeSize = DL.getTypeAllocSize(AI->getAllocatedType()); + // Make sure that, even if the multiplication below would wrap as an + // uint64_t, we still do the right thing. + if ((CS->getValue().zext(128) * APInt(128, TypeSize)).ugt(MaxSize)) + return false; + continue; + } + + if (GlobalVariable *GV = dyn_cast<GlobalVariable>(P)) { + if (!GV->hasDefinitiveInitializer() || !GV->isConstant()) + return false; + + uint64_t InitSize = DL.getTypeAllocSize(GV->getValueType()); + if (InitSize > MaxSize) + return false; + continue; + } + + return false; + } while (!Worklist.empty()); + + return true; +} + +// If we're indexing into an object of a known size, and the outer index is +// not a constant, but having any value but zero would lead to undefined +// behavior, replace it with zero. +// +// For example, if we have: +// @f.a = private unnamed_addr constant [1 x i32] [i32 12], align 4 +// ... 
+// %arrayidx = getelementptr inbounds [1 x i32]* @f.a, i64 0, i64 %x +// ... = load i32* %arrayidx, align 4 +// Then we know that we can replace %x in the GEP with i64 0. +// +// FIXME: We could fold any GEP index to zero that would cause UB if it were +// not zero. Currently, we only handle the first such index. Also, we could +// also search through non-zero constant indices if we kept track of the +// offsets those indices implied. +static bool canReplaceGEPIdxWithZero(InstCombinerImpl &IC, + GetElementPtrInst *GEPI, Instruction *MemI, + unsigned &Idx) { + if (GEPI->getNumOperands() < 2) + return false; + + // Find the first non-zero index of a GEP. If all indices are zero, return + // one past the last index. + auto FirstNZIdx = [](const GetElementPtrInst *GEPI) { + unsigned I = 1; + for (unsigned IE = GEPI->getNumOperands(); I != IE; ++I) { + Value *V = GEPI->getOperand(I); + if (const ConstantInt *CI = dyn_cast<ConstantInt>(V)) + if (CI->isZero()) + continue; + + break; + } + + return I; + }; + + // Skip through initial 'zero' indices, and find the corresponding pointer + // type. See if the next index is not a constant. + Idx = FirstNZIdx(GEPI); + if (Idx == GEPI->getNumOperands()) + return false; + if (isa<Constant>(GEPI->getOperand(Idx))) + return false; + + SmallVector<Value *, 4> Ops(GEPI->idx_begin(), GEPI->idx_begin() + Idx); + Type *SourceElementType = GEPI->getSourceElementType(); + // Size information about scalable vectors is not available, so we cannot + // deduce whether indexing at n is undefined behaviour or not. Bail out. 
+ if (isa<ScalableVectorType>(SourceElementType)) + return false; + + Type *AllocTy = GetElementPtrInst::getIndexedType(SourceElementType, Ops); + if (!AllocTy || !AllocTy->isSized()) + return false; + const DataLayout &DL = IC.getDataLayout(); + uint64_t TyAllocSize = DL.getTypeAllocSize(AllocTy).getFixedSize(); + + // If there are more indices after the one we might replace with a zero, make + // sure they're all non-negative. If any of them are negative, the overall + // address being computed might be before the base address determined by the + // first non-zero index. + auto IsAllNonNegative = [&]() { + for (unsigned i = Idx+1, e = GEPI->getNumOperands(); i != e; ++i) { + KnownBits Known = IC.computeKnownBits(GEPI->getOperand(i), 0, MemI); + if (Known.isNonNegative()) + continue; + return false; + } + + return true; + }; + + // FIXME: If the GEP is not inbounds, and there are extra indices after the + // one we'll replace, those could cause the address computation to wrap + // (rendering the IsAllNonNegative() check below insufficient). We can do + // better, ignoring zero indices (and other indices we can prove small + // enough not to wrap). + if (Idx+1 != GEPI->getNumOperands() && !GEPI->isInBounds()) + return false; + + // Note that isObjectSizeLessThanOrEq will return true only if the pointer is + // also known to be dereferenceable. + return isObjectSizeLessThanOrEq(GEPI->getOperand(0), TyAllocSize, DL) && + IsAllNonNegative(); +} + +// If we're indexing into an object with a variable index for the memory +// access, but the object has only one element, we can assume that the index +// will always be zero. If we replace the GEP, return it. 
+template <typename T> +static Instruction *replaceGEPIdxWithZero(InstCombinerImpl &IC, Value *Ptr, + T &MemI) { + if (GetElementPtrInst *GEPI = dyn_cast<GetElementPtrInst>(Ptr)) { + unsigned Idx; + if (canReplaceGEPIdxWithZero(IC, GEPI, &MemI, Idx)) { + Instruction *NewGEPI = GEPI->clone(); + NewGEPI->setOperand(Idx, + ConstantInt::get(GEPI->getOperand(Idx)->getType(), 0)); + NewGEPI->insertBefore(GEPI); + MemI.setOperand(MemI.getPointerOperandIndex(), NewGEPI); + return NewGEPI; + } + } + + return nullptr; +} + +static bool canSimplifyNullStoreOrGEP(StoreInst &SI) { + if (NullPointerIsDefined(SI.getFunction(), SI.getPointerAddressSpace())) + return false; + + auto *Ptr = SI.getPointerOperand(); + if (GetElementPtrInst *GEPI = dyn_cast<GetElementPtrInst>(Ptr)) + Ptr = GEPI->getOperand(0); + return (isa<ConstantPointerNull>(Ptr) && + !NullPointerIsDefined(SI.getFunction(), SI.getPointerAddressSpace())); +} + +static bool canSimplifyNullLoadOrGEP(LoadInst &LI, Value *Op) { + if (GetElementPtrInst *GEPI = dyn_cast<GetElementPtrInst>(Op)) { + const Value *GEPI0 = GEPI->getOperand(0); + if (isa<ConstantPointerNull>(GEPI0) && + !NullPointerIsDefined(LI.getFunction(), GEPI->getPointerAddressSpace())) + return true; + } + if (isa<UndefValue>(Op) || + (isa<ConstantPointerNull>(Op) && + !NullPointerIsDefined(LI.getFunction(), LI.getPointerAddressSpace()))) + return true; + return false; +} + +Instruction *InstCombinerImpl::visitLoadInst(LoadInst &LI) { + Value *Op = LI.getOperand(0); + + // Try to canonicalize the loaded type. + if (Instruction *Res = combineLoadToOperationType(*this, LI)) + return Res; + + // Attempt to improve the alignment. + Align KnownAlign = getOrEnforceKnownAlignment( + Op, DL.getPrefTypeAlign(LI.getType()), DL, &LI, &AC, &DT); + if (KnownAlign > LI.getAlign()) + LI.setAlignment(KnownAlign); + + // Replace GEP indices if possible. 
+ if (Instruction *NewGEPI = replaceGEPIdxWithZero(*this, Op, LI)) { + Worklist.push(NewGEPI); + return &LI; + } + + if (Instruction *Res = unpackLoadToAggregate(*this, LI)) + return Res; + + // Do really simple store-to-load forwarding and load CSE, to catch cases + // where there are several consecutive memory accesses to the same location, + // separated by a few arithmetic operations. + bool IsLoadCSE = false; + if (Value *AvailableVal = FindAvailableLoadedValue(&LI, *AA, &IsLoadCSE)) { + if (IsLoadCSE) + combineMetadataForCSE(cast<LoadInst>(AvailableVal), &LI, false); + + return replaceInstUsesWith( + LI, Builder.CreateBitOrPointerCast(AvailableVal, LI.getType(), + LI.getName() + ".cast")); + } + + // None of the following transforms are legal for volatile/ordered atomic + // loads. Most of them do apply for unordered atomics. + if (!LI.isUnordered()) return nullptr; + + // load(gep null, ...) -> unreachable + // load null/undef -> unreachable + // TODO: Consider a target hook for valid address spaces for this xforms. + if (canSimplifyNullLoadOrGEP(LI, Op)) { + // Insert a new store to null instruction before the load to indicate + // that this code is not reachable. We do this instead of inserting + // an unreachable instruction directly because we cannot modify the + // CFG. + StoreInst *SI = new StoreInst(PoisonValue::get(LI.getType()), + Constant::getNullValue(Op->getType()), &LI); + SI->setDebugLoc(LI.getDebugLoc()); + return replaceInstUsesWith(LI, PoisonValue::get(LI.getType())); + } + + if (Op->hasOneUse()) { + // Change select and PHI nodes to select values instead of addresses: this + // helps alias analysis out a lot, allows many others simplifications, and + // exposes redundancy in the code. + // + // Note that we cannot do the transformation unless we know that the + // introduced loads cannot trap! 
Something like this is valid as long as + // the condition is always false: load (select bool %C, int* null, int* %G), + // but it would not be valid if we transformed it to load from null + // unconditionally. + // + if (SelectInst *SI = dyn_cast<SelectInst>(Op)) { + // load (select (Cond, &V1, &V2)) --> select(Cond, load &V1, load &V2). + Align Alignment = LI.getAlign(); + if (isSafeToLoadUnconditionally(SI->getOperand(1), LI.getType(), + Alignment, DL, SI) && + isSafeToLoadUnconditionally(SI->getOperand(2), LI.getType(), + Alignment, DL, SI)) { + LoadInst *V1 = + Builder.CreateLoad(LI.getType(), SI->getOperand(1), + SI->getOperand(1)->getName() + ".val"); + LoadInst *V2 = + Builder.CreateLoad(LI.getType(), SI->getOperand(2), + SI->getOperand(2)->getName() + ".val"); + assert(LI.isUnordered() && "implied by above"); + V1->setAlignment(Alignment); + V1->setAtomic(LI.getOrdering(), LI.getSyncScopeID()); + V2->setAlignment(Alignment); + V2->setAtomic(LI.getOrdering(), LI.getSyncScopeID()); + return SelectInst::Create(SI->getCondition(), V1, V2); + } + + // load (select (cond, null, P)) -> load P + if (isa<ConstantPointerNull>(SI->getOperand(1)) && + !NullPointerIsDefined(SI->getFunction(), + LI.getPointerAddressSpace())) + return replaceOperand(LI, 0, SI->getOperand(2)); + + // load (select (cond, P, null)) -> load P + if (isa<ConstantPointerNull>(SI->getOperand(2)) && + !NullPointerIsDefined(SI->getFunction(), + LI.getPointerAddressSpace())) + return replaceOperand(LI, 0, SI->getOperand(1)); + } + } + return nullptr; +} + +/// Look for extractelement/insertvalue sequence that acts like a bitcast. +/// +/// \returns underlying value that was "cast", or nullptr otherwise. 
+/// +/// For example, if we have: +/// +/// %E0 = extractelement <2 x double> %U, i32 0 +/// %V0 = insertvalue [2 x double] undef, double %E0, 0 +/// %E1 = extractelement <2 x double> %U, i32 1 +/// %V1 = insertvalue [2 x double] %V0, double %E1, 1 +/// +/// and the layout of a <2 x double> is isomorphic to a [2 x double], +/// then %V1 can be safely approximated by a conceptual "bitcast" of %U. +/// Note that %U may contain non-undef values where %V1 has undef. +static Value *likeBitCastFromVector(InstCombinerImpl &IC, Value *V) { + Value *U = nullptr; + while (auto *IV = dyn_cast<InsertValueInst>(V)) { + auto *E = dyn_cast<ExtractElementInst>(IV->getInsertedValueOperand()); + if (!E) + return nullptr; + auto *W = E->getVectorOperand(); + if (!U) + U = W; + else if (U != W) + return nullptr; + auto *CI = dyn_cast<ConstantInt>(E->getIndexOperand()); + if (!CI || IV->getNumIndices() != 1 || CI->getZExtValue() != *IV->idx_begin()) + return nullptr; + V = IV->getAggregateOperand(); + } + if (!match(V, m_Undef()) || !U) + return nullptr; + + auto *UT = cast<VectorType>(U->getType()); + auto *VT = V->getType(); + // Check that types UT and VT are bitwise isomorphic. + const auto &DL = IC.getDataLayout(); + if (DL.getTypeStoreSizeInBits(UT) != DL.getTypeStoreSizeInBits(VT)) { + return nullptr; + } + if (auto *AT = dyn_cast<ArrayType>(VT)) { + if (AT->getNumElements() != cast<FixedVectorType>(UT)->getNumElements()) + return nullptr; + } else { + auto *ST = cast<StructType>(VT); + if (ST->getNumElements() != cast<FixedVectorType>(UT)->getNumElements()) + return nullptr; + for (const auto *EltT : ST->elements()) { + if (EltT != UT->getElementType()) + return nullptr; + } + } + return U; +} + +/// Combine stores to match the type of value being stored. +/// +/// The core idea here is that the memory does not have any intrinsic type and +/// where we can we should match the type of a store to the type of value being +/// stored. 
+/// +/// However, this routine must never change the width of a store or the number of +/// stores as that would introduce a semantic change. This combine is expected to +/// be a semantic no-op which just allows stores to more closely model the types +/// of their incoming values. +/// +/// Currently, we also refuse to change the precise type used for an atomic or +/// volatile store. This is debatable, and might be reasonable to change later. +/// However, it is risky in case some backend or other part of LLVM is relying +/// on the exact type stored to select appropriate atomic operations. +/// +/// \returns true if the store was successfully combined away. This indicates +/// the caller must erase the store instruction. We have to let the caller erase +/// the store instruction as otherwise there is no way to signal whether it was +/// combined or not: IC.EraseInstFromFunction returns a null pointer. +static bool combineStoreToValueType(InstCombinerImpl &IC, StoreInst &SI) { + // FIXME: We could probably with some care handle both volatile and ordered + // atomic stores here but it isn't clear that this is important. + if (!SI.isUnordered()) + return false; + + // swifterror values can't be bitcasted. + if (SI.getPointerOperand()->isSwiftError()) + return false; + + Value *V = SI.getValueOperand(); + + // Fold away bit casts of the stored value by storing the original type. + if (auto *BC = dyn_cast<BitCastInst>(V)) { + assert(!BC->getType()->isX86_AMXTy() && + "store to x86_amx* should not happen!"); + V = BC->getOperand(0); + // Don't transform when the type is x86_amx, it makes the pass that lower + // x86_amx type happy. 
+ if (V->getType()->isX86_AMXTy()) + return false; + if (!SI.isAtomic() || isSupportedAtomicType(V->getType())) { + combineStoreToNewValue(IC, SI, V); + return true; + } + } + + if (Value *U = likeBitCastFromVector(IC, V)) + if (!SI.isAtomic() || isSupportedAtomicType(U->getType())) { + combineStoreToNewValue(IC, SI, U); + return true; + } + + // FIXME: We should also canonicalize stores of vectors when their elements + // are cast to other types. + return false; +} + +static bool unpackStoreToAggregate(InstCombinerImpl &IC, StoreInst &SI) { + // FIXME: We could probably with some care handle both volatile and atomic + // stores here but it isn't clear that this is important. + if (!SI.isSimple()) + return false; + + Value *V = SI.getValueOperand(); + Type *T = V->getType(); + + if (!T->isAggregateType()) + return false; + + if (auto *ST = dyn_cast<StructType>(T)) { + // If the struct only have one element, we unpack. + unsigned Count = ST->getNumElements(); + if (Count == 1) { + V = IC.Builder.CreateExtractValue(V, 0); + combineStoreToNewValue(IC, SI, V); + return true; + } + + // We don't want to break loads with padding here as we'd loose + // the knowledge that padding exists for the rest of the pipeline. 
+ const DataLayout &DL = IC.getDataLayout(); + auto *SL = DL.getStructLayout(ST); + if (SL->hasPadding()) + return false; + + const auto Align = SI.getAlign(); + + SmallString<16> EltName = V->getName(); + EltName += ".elt"; + auto *Addr = SI.getPointerOperand(); + SmallString<16> AddrName = Addr->getName(); + AddrName += ".repack"; + + auto *IdxType = Type::getInt32Ty(ST->getContext()); + auto *Zero = ConstantInt::get(IdxType, 0); + for (unsigned i = 0; i < Count; i++) { + Value *Indices[2] = { + Zero, + ConstantInt::get(IdxType, i), + }; + auto *Ptr = IC.Builder.CreateInBoundsGEP(ST, Addr, makeArrayRef(Indices), + AddrName); + auto *Val = IC.Builder.CreateExtractValue(V, i, EltName); + auto EltAlign = commonAlignment(Align, SL->getElementOffset(i)); + llvm::Instruction *NS = IC.Builder.CreateAlignedStore(Val, Ptr, EltAlign); + NS->setAAMetadata(SI.getAAMetadata()); + } + + return true; + } + + if (auto *AT = dyn_cast<ArrayType>(T)) { + // If the array only have one element, we unpack. + auto NumElements = AT->getNumElements(); + if (NumElements == 1) { + V = IC.Builder.CreateExtractValue(V, 0); + combineStoreToNewValue(IC, SI, V); + return true; + } + + // Bail out if the array is too large. Ideally we would like to optimize + // arrays of arbitrary size but this has a terrible impact on compile time. + // The threshold here is chosen arbitrarily, maybe needs a little bit of + // tuning. 
+ if (NumElements > IC.MaxArraySizeForCombine) + return false; + + const DataLayout &DL = IC.getDataLayout(); + auto EltSize = DL.getTypeAllocSize(AT->getElementType()); + const auto Align = SI.getAlign(); + + SmallString<16> EltName = V->getName(); + EltName += ".elt"; + auto *Addr = SI.getPointerOperand(); + SmallString<16> AddrName = Addr->getName(); + AddrName += ".repack"; + + auto *IdxType = Type::getInt64Ty(T->getContext()); + auto *Zero = ConstantInt::get(IdxType, 0); + + uint64_t Offset = 0; + for (uint64_t i = 0; i < NumElements; i++) { + Value *Indices[2] = { + Zero, + ConstantInt::get(IdxType, i), + }; + auto *Ptr = IC.Builder.CreateInBoundsGEP(AT, Addr, makeArrayRef(Indices), + AddrName); + auto *Val = IC.Builder.CreateExtractValue(V, i, EltName); + auto EltAlign = commonAlignment(Align, Offset); + Instruction *NS = IC.Builder.CreateAlignedStore(Val, Ptr, EltAlign); + NS->setAAMetadata(SI.getAAMetadata()); + Offset += EltSize; + } + + return true; + } + + return false; +} + +/// equivalentAddressValues - Test if A and B will obviously have the same +/// value. This includes recognizing that %t0 and %t1 will have the same +/// value in code like this: +/// %t0 = getelementptr \@a, 0, 3 +/// store i32 0, i32* %t0 +/// %t1 = getelementptr \@a, 0, 3 +/// %t2 = load i32* %t1 +/// +static bool equivalentAddressValues(Value *A, Value *B) { + // Test if the values are trivially equivalent. + if (A == B) return true; + + // Test if the values come form identical arithmetic instructions. + // This uses isIdenticalToWhenDefined instead of isIdenticalTo because + // its only used to compare two uses within the same basic block, which + // means that they'll always either have the same value or one of them + // will have an undefined value. 
+ if (isa<BinaryOperator>(A) || + isa<CastInst>(A) || + isa<PHINode>(A) || + isa<GetElementPtrInst>(A)) + if (Instruction *BI = dyn_cast<Instruction>(B)) + if (cast<Instruction>(A)->isIdenticalToWhenDefined(BI)) + return true; + + // Otherwise they may not be equivalent. + return false; +} + +/// Converts store (bitcast (load (bitcast (select ...)))) to +/// store (load (select ...)), where select is minmax: +/// select ((cmp load V1, load V2), V1, V2). +static bool removeBitcastsFromLoadStoreOnMinMax(InstCombinerImpl &IC, + StoreInst &SI) { + // bitcast? + if (!match(SI.getPointerOperand(), m_BitCast(m_Value()))) + return false; + // load? integer? + Value *LoadAddr; + if (!match(SI.getValueOperand(), m_Load(m_BitCast(m_Value(LoadAddr))))) + return false; + auto *LI = cast<LoadInst>(SI.getValueOperand()); + if (!LI->getType()->isIntegerTy()) + return false; + Type *CmpLoadTy; + if (!isMinMaxWithLoads(LoadAddr, CmpLoadTy)) + return false; + + // Make sure the type would actually change. + // This condition can be hit with chains of bitcasts. + if (LI->getType() == CmpLoadTy) + return false; + + // Make sure we're not changing the size of the load/store. + const auto &DL = IC.getDataLayout(); + if (DL.getTypeStoreSizeInBits(LI->getType()) != + DL.getTypeStoreSizeInBits(CmpLoadTy)) + return false; + + if (!all_of(LI->users(), [LI, LoadAddr](User *U) { + auto *SI = dyn_cast<StoreInst>(U); + return SI && SI->getPointerOperand() != LI && + InstCombiner::peekThroughBitcast(SI->getPointerOperand()) != + LoadAddr && + !SI->getPointerOperand()->isSwiftError(); + })) + return false; + + IC.Builder.SetInsertPoint(LI); + LoadInst *NewLI = IC.combineLoadToNewType(*LI, CmpLoadTy); + // Replace all the stores with stores of the newly loaded value. 
+ for (auto *UI : LI->users()) { + auto *USI = cast<StoreInst>(UI); + IC.Builder.SetInsertPoint(USI); + combineStoreToNewValue(IC, *USI, NewLI); + } + IC.replaceInstUsesWith(*LI, PoisonValue::get(LI->getType())); + IC.eraseInstFromFunction(*LI); + return true; +} + +Instruction *InstCombinerImpl::visitStoreInst(StoreInst &SI) { + Value *Val = SI.getOperand(0); + Value *Ptr = SI.getOperand(1); + + // Try to canonicalize the stored type. + if (combineStoreToValueType(*this, SI)) + return eraseInstFromFunction(SI); + + // Attempt to improve the alignment. + const Align KnownAlign = getOrEnforceKnownAlignment( + Ptr, DL.getPrefTypeAlign(Val->getType()), DL, &SI, &AC, &DT); + if (KnownAlign > SI.getAlign()) + SI.setAlignment(KnownAlign); + + // Try to canonicalize the stored type. + if (unpackStoreToAggregate(*this, SI)) + return eraseInstFromFunction(SI); + + if (removeBitcastsFromLoadStoreOnMinMax(*this, SI)) + return eraseInstFromFunction(SI); + + // Replace GEP indices if possible. + if (Instruction *NewGEPI = replaceGEPIdxWithZero(*this, Ptr, SI)) { + Worklist.push(NewGEPI); + return &SI; + } + + // Don't hack volatile/ordered stores. + // FIXME: Some bits are legal for ordered atomic stores; needs refactoring. + if (!SI.isUnordered()) return nullptr; + + // If the RHS is an alloca with a single use, zapify the store, making the + // alloca dead. + if (Ptr->hasOneUse()) { + if (isa<AllocaInst>(Ptr)) + return eraseInstFromFunction(SI); + if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Ptr)) { + if (isa<AllocaInst>(GEP->getOperand(0))) { + if (GEP->getOperand(0)->hasOneUse()) + return eraseInstFromFunction(SI); + } + } + } + + // If we have a store to a location which is known constant, we can conclude + // that the store must be storing the constant value (else the memory + // wouldn't be constant), and this must be a noop. 
+ if (AA->pointsToConstantMemory(Ptr)) + return eraseInstFromFunction(SI); + + // Do really simple DSE, to catch cases where there are several consecutive + // stores to the same location, separated by a few arithmetic operations. This + // situation often occurs with bitfield accesses. + BasicBlock::iterator BBI(SI); + for (unsigned ScanInsts = 6; BBI != SI.getParent()->begin() && ScanInsts; + --ScanInsts) { + --BBI; + // Don't count debug info directives, lest they affect codegen, + // and we skip pointer-to-pointer bitcasts, which are NOPs. + if (BBI->isDebugOrPseudoInst() || + (isa<BitCastInst>(BBI) && BBI->getType()->isPointerTy())) { + ScanInsts++; + continue; + } + + if (StoreInst *PrevSI = dyn_cast<StoreInst>(BBI)) { + // Prev store isn't volatile, and stores to the same location? + if (PrevSI->isUnordered() && + equivalentAddressValues(PrevSI->getOperand(1), SI.getOperand(1)) && + PrevSI->getValueOperand()->getType() == + SI.getValueOperand()->getType()) { + ++NumDeadStore; + // Manually add back the original store to the worklist now, so it will + // be processed after the operands of the removed store, as this may + // expose additional DSE opportunities. + Worklist.push(&SI); + eraseInstFromFunction(*PrevSI); + return nullptr; + } + break; + } + + // If this is a load, we have to stop. However, if the loaded value is from + // the pointer we're loading and is producing the pointer we're storing, + // then *this* store is dead (X = load P; store X -> P). + if (LoadInst *LI = dyn_cast<LoadInst>(BBI)) { + if (LI == Val && equivalentAddressValues(LI->getOperand(0), Ptr)) { + assert(SI.isUnordered() && "can't eliminate ordering operation"); + return eraseInstFromFunction(SI); + } + + // Otherwise, this is a load from some other location. Stores before it + // may not be dead. + break; + } + + // Don't skip over loads, throws or things that can modify memory. 
/// Try to transform:
///   if () { *P = v1; } else { *P = v2 }
/// or:
///   *P = v1; if () { *P = v2; }
/// into a phi node with a store in the successor.
///
/// \param SI the store to sink; it must be unordered or the fold is refused.
/// \returns true if SI and the matching store in the other predecessor were
/// erased and replaced by a single store in the common successor, false if
/// nothing was changed.
bool InstCombinerImpl::mergeStoreIntoSuccessor(StoreInst &SI) {
  if (!SI.isUnordered())
    return false; // This code has not been audited for volatile/ordered case.

  // Check if the successor block has exactly 2 incoming edges.
  // Only successor 0 of the terminator is considered.
  // NOTE(review): presumably the caller guarantees the terminator has at
  // least one successor -- confirm at the call site.
  BasicBlock *StoreBB = SI.getParent();
  BasicBlock *DestBB = StoreBB->getTerminator()->getSuccessor(0);
  if (!DestBB->hasNPredecessors(2))
    return false;

  // Capture the other block (the block that doesn't contain our store).
  pred_iterator PredIter = pred_begin(DestBB);
  if (*PredIter == StoreBB)
    ++PredIter;
  BasicBlock *OtherBB = *PredIter;

  // Bail out if all of the relevant blocks aren't distinct. This can happen,
  // for example, if SI is in an infinite loop.
  if (StoreBB == DestBB || OtherBB == DestBB)
    return false;

  // Verify that the other block ends in a branch and is not otherwise empty.
  BasicBlock::iterator BBI(OtherBB->getTerminator());
  BranchInst *OtherBr = dyn_cast<BranchInst>(BBI);
  if (!OtherBr || BBI == OtherBB->begin())
    return false;

  // If the other block ends in an unconditional branch, check for the 'if then
  // else' case. There is an instruction before the branch.
  StoreInst *OtherStore = nullptr;
  if (OtherBr->isUnconditional()) {
    --BBI;
    // Skip over debugging info and pseudo probes. Pointer-typed bitcasts are
    // skipped as well. Running off the front of the block means there is no
    // candidate store.
    while (BBI->isDebugOrPseudoInst() ||
           (isa<BitCastInst>(BBI) && BBI->getType()->isPointerTy())) {
      if (BBI==OtherBB->begin())
        return false;
      --BBI;
    }
    // If this isn't a store, isn't a store to the same location, or is not the
    // right kind of store, bail out.
    OtherStore = dyn_cast<StoreInst>(BBI);
    if (!OtherStore || OtherStore->getOperand(1) != SI.getOperand(1) ||
        !SI.isSameOperationAs(OtherStore))
      return false;
  } else {
    // Otherwise, the other block ended with a conditional branch. If one of the
    // destinations is StoreBB, then we have the if/then case.
    if (OtherBr->getSuccessor(0) != StoreBB &&
        OtherBr->getSuccessor(1) != StoreBB)
      return false;

    // Okay, we know that OtherBr now goes to Dest and StoreBB, so this is an
    // if/then triangle. See if there is a store to the same ptr as SI that
    // lives in OtherBB.
    // The scan starts at the terminator itself and walks backward; the only
    // exits are the 'break' on a matching store and the 'return false' paths.
    for (;; --BBI) {
      // Check to see if we find the matching store.
      if ((OtherStore = dyn_cast<StoreInst>(BBI))) {
        if (OtherStore->getOperand(1) != SI.getOperand(1) ||
            !SI.isSameOperationAs(OtherStore))
          return false;
        break;
      }
      // If we find something that may be using or overwriting the stored
      // value, or if we run out of instructions, we can't do the transform.
      if (BBI->mayReadFromMemory() || BBI->mayThrow() ||
          BBI->mayWriteToMemory() || BBI == OtherBB->begin())
        return false;
    }

    // In order to eliminate the store in OtherBr, we have to make sure nothing
    // reads or overwrites the stored value in StoreBB.
    for (BasicBlock::iterator I = StoreBB->begin(); &*I != &SI; ++I) {
      // FIXME: This should really be AA driven.
      if (I->mayReadFromMemory() || I->mayThrow() || I->mayWriteToMemory())
        return false;
    }
  }

  // Insert a PHI node now if we need it.
  // When both stores write the same value no PHI is needed and MergedVal
  // stays as that value.
  Value *MergedVal = OtherStore->getOperand(0);
  // The debug locations of the original instructions might differ. Merge them.
  DebugLoc MergedLoc = DILocation::getMergedLocation(SI.getDebugLoc(),
                                                     OtherStore->getDebugLoc());
  if (MergedVal != SI.getOperand(0)) {
    PHINode *PN = PHINode::Create(MergedVal->getType(), 2, "storemerge");
    PN->addIncoming(SI.getOperand(0), SI.getParent());
    PN->addIncoming(OtherStore->getOperand(0), OtherBB);
    MergedVal = InsertNewInstBefore(PN, DestBB->front());
    PN->setDebugLoc(MergedLoc);
  }

  // Advance to a place where it is safe to insert the new store and insert it.
  // The new store copies SI's pointer, volatility, alignment, ordering and
  // sync scope.
  BBI = DestBB->getFirstInsertionPt();
  StoreInst *NewSI =
      new StoreInst(MergedVal, SI.getOperand(1), SI.isVolatile(), SI.getAlign(),
                    SI.getOrdering(), SI.getSyncScopeID());
  InsertNewInstBefore(NewSI, *BBI);
  NewSI->setDebugLoc(MergedLoc);

  // If the two stores had AA tags, merge them.
  // Only SI's tags are tested: when SI carries none, NewSI gets no AA
  // metadata regardless of what OtherStore carries.
  AAMDNodes AATags = SI.getAAMetadata();
  if (AATags)
    NewSI->setAAMetadata(AATags.merge(OtherStore->getAAMetadata()));

  // Nuke the old stores.
  eraseInstFromFunction(SI);
  eraseInstFromFunction(*OtherStore);
  return true;
}
+// +//===----------------------------------------------------------------------===// + +#include "InstCombineInternal.h" +#include "llvm/ADT/APInt.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Analysis/InstructionSimplify.h" +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/Constant.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/InstrTypes.h" +#include "llvm/IR/Instruction.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/Intrinsics.h" +#include "llvm/IR/Operator.h" +#include "llvm/IR/PatternMatch.h" +#include "llvm/IR/Type.h" +#include "llvm/IR/Value.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/ErrorHandling.h" +#include "llvm/Transforms/InstCombine/InstCombiner.h" +#include "llvm/Transforms/Utils/BuildLibCalls.h" +#include <cassert> + +#define DEBUG_TYPE "instcombine" +#include "llvm/Transforms/Utils/InstructionWorklist.h" + +using namespace llvm; +using namespace PatternMatch; + +/// The specific integer value is used in a context where it is known to be +/// non-zero. If this allows us to simplify the computation, do so and return +/// the new operand, otherwise return null. +static Value *simplifyValueKnownNonZero(Value *V, InstCombinerImpl &IC, + Instruction &CxtI) { + // If V has multiple uses, then we would have to do more analysis to determine + // if this is safe. For example, the use could be in dynamically unreached + // code. + if (!V->hasOneUse()) return nullptr; + + bool MadeChange = false; + + // ((1 << A) >>u B) --> (1 << (A-B)) + // Because V cannot be zero, we know that B is less than A. + Value *A = nullptr, *B = nullptr, *One = nullptr; + if (match(V, m_LShr(m_OneUse(m_Shl(m_Value(One), m_Value(A))), m_Value(B))) && + match(One, m_One())) { + A = IC.Builder.CreateSub(A, B); + return IC.Builder.CreateShl(One, A); + } + + // (PowerOfTwo >>u B) --> isExact since shifting out the result would make it + // inexact. Similarly for <<. 
+ BinaryOperator *I = dyn_cast<BinaryOperator>(V); + if (I && I->isLogicalShift() && + IC.isKnownToBeAPowerOfTwo(I->getOperand(0), false, 0, &CxtI)) { + // We know that this is an exact/nuw shift and that the input is a + // non-zero context as well. + if (Value *V2 = simplifyValueKnownNonZero(I->getOperand(0), IC, CxtI)) { + IC.replaceOperand(*I, 0, V2); + MadeChange = true; + } + + if (I->getOpcode() == Instruction::LShr && !I->isExact()) { + I->setIsExact(); + MadeChange = true; + } + + if (I->getOpcode() == Instruction::Shl && !I->hasNoUnsignedWrap()) { + I->setHasNoUnsignedWrap(); + MadeChange = true; + } + } + + // TODO: Lots more we could do here: + // If V is a phi node, we can call this on each of its operands. + // "select cond, X, 0" can simplify to "X". + + return MadeChange ? V : nullptr; +} + +// TODO: This is a specific form of a much more general pattern. +// We could detect a select with any binop identity constant, or we +// could use SimplifyBinOp to see if either arm of the select reduces. +// But that needs to be done carefully and/or while removing potential +// reverse canonicalizations as in InstCombiner::foldSelectIntoOp(). 
+static Value *foldMulSelectToNegate(BinaryOperator &I, + InstCombiner::BuilderTy &Builder) { + Value *Cond, *OtherOp; + + // mul (select Cond, 1, -1), OtherOp --> select Cond, OtherOp, -OtherOp + // mul OtherOp, (select Cond, 1, -1) --> select Cond, OtherOp, -OtherOp + if (match(&I, m_c_Mul(m_OneUse(m_Select(m_Value(Cond), m_One(), m_AllOnes())), + m_Value(OtherOp)))) { + bool HasAnyNoWrap = I.hasNoSignedWrap() || I.hasNoUnsignedWrap(); + Value *Neg = Builder.CreateNeg(OtherOp, "", false, HasAnyNoWrap); + return Builder.CreateSelect(Cond, OtherOp, Neg); + } + // mul (select Cond, -1, 1), OtherOp --> select Cond, -OtherOp, OtherOp + // mul OtherOp, (select Cond, -1, 1) --> select Cond, -OtherOp, OtherOp + if (match(&I, m_c_Mul(m_OneUse(m_Select(m_Value(Cond), m_AllOnes(), m_One())), + m_Value(OtherOp)))) { + bool HasAnyNoWrap = I.hasNoSignedWrap() || I.hasNoUnsignedWrap(); + Value *Neg = Builder.CreateNeg(OtherOp, "", false, HasAnyNoWrap); + return Builder.CreateSelect(Cond, Neg, OtherOp); + } + + // fmul (select Cond, 1.0, -1.0), OtherOp --> select Cond, OtherOp, -OtherOp + // fmul OtherOp, (select Cond, 1.0, -1.0) --> select Cond, OtherOp, -OtherOp + if (match(&I, m_c_FMul(m_OneUse(m_Select(m_Value(Cond), m_SpecificFP(1.0), + m_SpecificFP(-1.0))), + m_Value(OtherOp)))) { + IRBuilder<>::FastMathFlagGuard FMFGuard(Builder); + Builder.setFastMathFlags(I.getFastMathFlags()); + return Builder.CreateSelect(Cond, OtherOp, Builder.CreateFNeg(OtherOp)); + } + + // fmul (select Cond, -1.0, 1.0), OtherOp --> select Cond, -OtherOp, OtherOp + // fmul OtherOp, (select Cond, -1.0, 1.0) --> select Cond, -OtherOp, OtherOp + if (match(&I, m_c_FMul(m_OneUse(m_Select(m_Value(Cond), m_SpecificFP(-1.0), + m_SpecificFP(1.0))), + m_Value(OtherOp)))) { + IRBuilder<>::FastMathFlagGuard FMFGuard(Builder); + Builder.setFastMathFlags(I.getFastMathFlags()); + return Builder.CreateSelect(Cond, Builder.CreateFNeg(OtherOp), OtherOp); + } + + return nullptr; +} + +Instruction 
/// Visit an integer 'mul' instruction and try a fixed, ordered list of folds.
/// The first fold that applies wins, so the order of the checks below is
/// part of the behavior.
///
/// \returns a new (not yet inserted) instruction that replaces \p I, \p I
/// itself when it was modified in place, the result of replaceInstUsesWith
/// when an existing/built value takes over, or nullptr if nothing changed.
Instruction *InstCombinerImpl::visitMul(BinaryOperator &I) {
  if (Value *V = simplifyMulInst(I.getOperand(0), I.getOperand(1),
                                 SQ.getWithInstruction(&I)))
    return replaceInstUsesWith(I, V);

  if (SimplifyAssociativeOrCommutative(I))
    return &I;

  if (Instruction *X = foldVectorBinop(I))
    return X;

  if (Instruction *Phi = foldBinopWithPhiOperands(I))
    return Phi;

  if (Value *V = SimplifyUsingDistributiveLaws(I))
    return replaceInstUsesWith(I, V);

  // Operands are read only now, after the helpers above had their chance to
  // rewrite I in place.
  Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
  unsigned BitWidth = I.getType()->getScalarSizeInBits();

  // X * -1 == 0 - X
  if (match(Op1, m_AllOnes())) {
    BinaryOperator *BO = BinaryOperator::CreateNeg(Op0, I.getName());
    if (I.hasNoSignedWrap())
      BO->setHasNoSignedWrap();
    return BO;
  }

  // Also allow combining multiply instructions on vectors.
  {
    Value *NewOp;
    Constant *C1, *C2;
    const APInt *IVal;
    if (match(&I, m_Mul(m_Shl(m_Value(NewOp), m_Constant(C2)),
                        m_Constant(C1))) &&
        match(C1, m_APInt(IVal))) {
      // ((X << C2)*C1) == (X * (C1 << C2))
      Constant *Shl = ConstantExpr::getShl(C1, C2);
      // Despite its name, 'Mul' is operand 0 of I, i.e. the shl matched above.
      BinaryOperator *Mul = cast<BinaryOperator>(I.getOperand(0));
      BinaryOperator *BO = BinaryOperator::CreateMul(NewOp, Shl);
      if (I.hasNoUnsignedWrap() && Mul->hasNoUnsignedWrap())
        BO->setHasNoUnsignedWrap();
      if (I.hasNoSignedWrap() && Mul->hasNoSignedWrap() &&
          Shl->isNotMinSignedValue())
        BO->setHasNoSignedWrap();
      return BO;
    }

    if (match(&I, m_Mul(m_Value(NewOp), m_Constant(C1)))) {
      // Replace X*(2^C) with X << C, where C is either a scalar or a vector.
      if (Constant *NewCst = ConstantExpr::getExactLogBase2(C1)) {
        BinaryOperator *Shl = BinaryOperator::CreateShl(NewOp, NewCst);

        if (I.hasNoUnsignedWrap())
          Shl->setHasNoUnsignedWrap();
        if (I.hasNoSignedWrap()) {
          // nsw is only carried over when the shift amount is a single APInt
          // that is not BitWidth-1.
          const APInt *V;
          if (match(NewCst, m_APInt(V)) && *V != V->getBitWidth() - 1)
            Shl->setHasNoSignedWrap();
        }

        return Shl;
      }
    }
  }

  if (Op0->hasOneUse() && match(Op1, m_NegatedPower2())) {
    // Interpret X * (-1<<C) as (-X) * (1<<C) and try to sink the negation.
    // The "* (1<<C)" thus becomes a potential shifting opportunity.
    if (Value *NegOp0 = Negator::Negate(/*IsNegation*/ true, Op0, *this))
      return BinaryOperator::CreateMul(
          NegOp0, ConstantExpr::getNeg(cast<Constant>(Op1)), I.getName());
  }

  if (Instruction *FoldedMul = foldBinOpIntoSelectOrPhi(I))
    return FoldedMul;

  if (Value *FoldedMul = foldMulSelectToNegate(I, Builder))
    return replaceInstUsesWith(I, FoldedMul);

  // Simplify mul instructions with a constant RHS.
  if (isa<Constant>(Op1)) {
    // Canonicalize (X+C1)*CI -> X*CI+C1*CI.
    Value *X;
    Constant *C1;
    if (match(Op0, m_OneUse(m_Add(m_Value(X), m_Constant(C1))))) {
      Value *Mul = Builder.CreateMul(C1, Op1);
      // Only go forward with the transform if C1*CI simplifies to a tidier
      // constant.
      if (!match(Mul, m_Mul(m_Value(), m_Value())))
        return BinaryOperator::CreateAdd(Builder.CreateMul(X, Op1), Mul);
    }
  }

  // abs(X) * abs(X) -> X * X
  // nabs(X) * nabs(X) -> X * X
  if (Op0 == Op1) {
    Value *X, *Y;
    SelectPatternFlavor SPF = matchSelectPattern(Op0, X, Y).Flavor;
    if (SPF == SPF_ABS || SPF == SPF_NABS)
      return BinaryOperator::CreateMul(X, X);

    if (match(Op0, m_Intrinsic<Intrinsic::abs>(m_Value(X))))
      return BinaryOperator::CreateMul(X, X);
  }

  // -X * C --> X * -C
  Value *X, *Y;
  Constant *Op1C;
  if (match(Op0, m_Neg(m_Value(X))) && match(Op1, m_Constant(Op1C)))
    return BinaryOperator::CreateMul(X, ConstantExpr::getNeg(Op1C));

  // -X * -Y --> X * Y
  if (match(Op0, m_Neg(m_Value(X))) && match(Op1, m_Neg(m_Value(Y)))) {
    auto *NewMul = BinaryOperator::CreateMul(X, Y);
    // nsw survives only if the multiply and both negations had it.
    if (I.hasNoSignedWrap() &&
        cast<OverflowingBinaryOperator>(Op0)->hasNoSignedWrap() &&
        cast<OverflowingBinaryOperator>(Op1)->hasNoSignedWrap())
      NewMul->setHasNoSignedWrap();
    return NewMul;
  }

  // -X * Y --> -(X * Y)
  // X * -Y --> -(X * Y)
  if (match(&I, m_c_Mul(m_OneUse(m_Neg(m_Value(X))), m_Value(Y))))
    return BinaryOperator::CreateNeg(Builder.CreateMul(X, Y));

  // (X / Y) * Y = X - (X % Y)
  // (X / Y) * -Y = (X % Y) - X
  {
    // Try Op0 as the division first; if it is not a udiv/sdiv, swap roles
    // and try Op1. The opcode is re-checked below for the swapped case.
    Value *Y = Op1;
    BinaryOperator *Div = dyn_cast<BinaryOperator>(Op0);
    if (!Div || (Div->getOpcode() != Instruction::UDiv &&
                 Div->getOpcode() != Instruction::SDiv)) {
      Y = Op0;
      Div = dyn_cast<BinaryOperator>(Op1);
    }
    Value *Neg = dyn_castNegVal(Y);
    if (Div && Div->hasOneUse() &&
        (Div->getOperand(1) == Y || Div->getOperand(1) == Neg) &&
        (Div->getOpcode() == Instruction::UDiv ||
         Div->getOpcode() == Instruction::SDiv)) {
      Value *X = Div->getOperand(0), *DivOp1 = Div->getOperand(1);

      // If the division is exact, X % Y is zero, so we end up with X or -X.
      if (Div->isExact()) {
        if (DivOp1 == Y)
          return replaceInstUsesWith(I, X);
        return BinaryOperator::CreateNeg(X);
      }

      auto RemOpc = Div->getOpcode() == Instruction::UDiv ? Instruction::URem
                                                          : Instruction::SRem;
      // X must be frozen because we are increasing its number of uses.
      Value *XFreeze = Builder.CreateFreeze(X, X->getName() + ".fr");
      Value *Rem = Builder.CreateBinOp(RemOpc, XFreeze, DivOp1);
      if (DivOp1 == Y)
        return BinaryOperator::CreateSub(XFreeze, Rem);
      return BinaryOperator::CreateSub(Rem, XFreeze);
    }
  }

  // Fold the following two scenarios:
  //   1) i1 mul -> i1 and.
  //   2) X * Y --> X & Y, iff X, Y can be only {0,1}.
  // Note: We could use known bits to generalize this and related patterns with
  // shifts/truncs
  Type *Ty = I.getType();
  if (Ty->isIntOrIntVectorTy(1) ||
      (match(Op0, m_And(m_Value(), m_One())) &&
       match(Op1, m_And(m_Value(), m_One()))))
    return BinaryOperator::CreateAnd(Op0, Op1);

  // X*(1 << Y) --> X << Y
  // (1 << Y)*X --> X << Y
  {
    Value *Y;
    BinaryOperator *BO = nullptr;
    bool ShlNSW = false;
    if (match(Op0, m_Shl(m_One(), m_Value(Y)))) {
      BO = BinaryOperator::CreateShl(Op1, Y);
      ShlNSW = cast<ShlOperator>(Op0)->hasNoSignedWrap();
    } else if (match(Op1, m_Shl(m_One(), m_Value(Y)))) {
      BO = BinaryOperator::CreateShl(Op0, Y);
      ShlNSW = cast<ShlOperator>(Op1)->hasNoSignedWrap();
    }
    if (BO) {
      // nuw comes from the multiply alone; nsw needs it on both the multiply
      // and the original (1 << Y).
      if (I.hasNoUnsignedWrap())
        BO->setHasNoUnsignedWrap();
      if (I.hasNoSignedWrap() && ShlNSW)
        BO->setHasNoSignedWrap();
      return BO;
    }
  }

  // (zext bool X) * (zext bool Y) --> zext (and X, Y)
  // (sext bool X) * (sext bool Y) --> zext (and X, Y)
  // Note: -1 * -1 == 1 * 1 == 1 (if the extends match, the result is the same)
  if (((match(Op0, m_ZExt(m_Value(X))) && match(Op1, m_ZExt(m_Value(Y)))) ||
       (match(Op0, m_SExt(m_Value(X))) && match(Op1, m_SExt(m_Value(Y))))) &&
      X->getType()->isIntOrIntVectorTy(1) && X->getType() == Y->getType() &&
      (Op0->hasOneUse() || Op1->hasOneUse() || X == Y)) {
    Value *And = Builder.CreateAnd(X, Y, "mulbool");
    return CastInst::Create(Instruction::ZExt, And, Ty);
  }
  // (sext bool X) * (zext bool Y) --> sext (and X, Y)
  // (zext bool X) * (sext bool Y) --> sext (and X, Y)
  // Note: -1 * 1 == 1 * -1  == -1
  if (((match(Op0, m_SExt(m_Value(X))) && match(Op1, m_ZExt(m_Value(Y)))) ||
       (match(Op0, m_ZExt(m_Value(X))) && match(Op1, m_SExt(m_Value(Y))))) &&
      X->getType()->isIntOrIntVectorTy(1) && X->getType() == Y->getType() &&
      (Op0->hasOneUse() || Op1->hasOneUse())) {
    Value *And = Builder.CreateAnd(X, Y, "mulbool");
    return CastInst::Create(Instruction::SExt, And, Ty);
  }

  // (zext bool X) * Y --> X ? Y : 0
  // Y * (zext bool X) --> X ? Y : 0
  if (match(Op0, m_ZExt(m_Value(X))) && X->getType()->isIntOrIntVectorTy(1))
    return SelectInst::Create(X, Op1, ConstantInt::getNullValue(Ty));
  if (match(Op1, m_ZExt(m_Value(X))) && X->getType()->isIntOrIntVectorTy(1))
    return SelectInst::Create(X, Op0, ConstantInt::getNullValue(Ty));

  Constant *ImmC;
  if (match(Op1, m_ImmConstant(ImmC))) {
    // (sext bool X) * C --> X ? -C : 0
    if (match(Op0, m_SExt(m_Value(X))) && X->getType()->isIntOrIntVectorTy(1)) {
      Constant *NegC = ConstantExpr::getNeg(ImmC);
      return SelectInst::Create(X, NegC, ConstantInt::getNullValue(Ty));
    }

    // (ashr i32 X, 31) * C --> (X < 0) ? -C : 0
    const APInt *C;
    if (match(Op0, m_OneUse(m_AShr(m_Value(X), m_APInt(C)))) &&
        *C == C->getBitWidth() - 1) {
      Constant *NegC = ConstantExpr::getNeg(ImmC);
      Value *IsNeg = Builder.CreateIsNeg(X, "isneg");
      return SelectInst::Create(IsNeg, NegC, ConstantInt::getNullValue(Ty));
    }
  }

  // (lshr X, 31) * Y --> (X < 0) ? Y : 0
  // TODO: We are not checking one-use because the elimination of the multiply
  //       is better for analysis?
  const APInt *C;
  if (match(&I, m_c_BinOp(m_LShr(m_Value(X), m_APInt(C)), m_Value(Y))) &&
      *C == C->getBitWidth() - 1) {
    Value *IsNeg = Builder.CreateIsNeg(X, "isneg");
    return SelectInst::Create(IsNeg, Y, ConstantInt::getNullValue(Ty));
  }

  // (and X, 1) * Y --> (trunc X) ? Y : 0
  if (match(&I, m_c_BinOp(m_OneUse(m_And(m_Value(X), m_One())), m_Value(Y)))) {
    Value *Tr = Builder.CreateTrunc(X, CmpInst::makeCmpResultType(Ty));
    return SelectInst::Create(Tr, Y, ConstantInt::getNullValue(Ty));
  }

  // ((ashr X, 31) | 1) * X --> abs(X)
  // X * ((ashr X, 31) | 1) --> abs(X)
  if (match(&I, m_c_BinOp(m_Or(m_AShr(m_Value(X),
                                      m_SpecificIntAllowUndef(BitWidth - 1)),
                               m_One()),
                          m_Deferred(X)))) {
    // The second argument of the abs intrinsic is set from the multiply's
    // nsw flag.
    Value *Abs = Builder.CreateBinaryIntrinsic(
        Intrinsic::abs, X,
        ConstantInt::getBool(I.getContext(), I.hasNoSignedWrap()));
    Abs->takeName(&I);
    return replaceInstUsesWith(I, Abs);
  }

  if (Instruction *Ext = narrowMathIfNoOverflow(I))
    return Ext;

  // No structural fold applied: as a last step, try to strengthen the
  // instruction's own no-wrap flags.
  bool Changed = false;
  if (!I.hasNoSignedWrap() && willNotOverflowSignedMul(Op0, Op1, I)) {
    Changed = true;
    I.setHasNoSignedWrap(true);
  }

  if (!I.hasNoUnsignedWrap() && willNotOverflowUnsignedMul(Op0, Op1, I)) {
    Changed = true;
    I.setHasNoUnsignedWrap(true);
  }

  return Changed ? &I : nullptr;
}
&I : nullptr; +} + +Instruction *InstCombinerImpl::foldFPSignBitOps(BinaryOperator &I) { + BinaryOperator::BinaryOps Opcode = I.getOpcode(); + assert((Opcode == Instruction::FMul || Opcode == Instruction::FDiv) && + "Expected fmul or fdiv"); + + Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); + Value *X, *Y; + + // -X * -Y --> X * Y + // -X / -Y --> X / Y + if (match(Op0, m_FNeg(m_Value(X))) && match(Op1, m_FNeg(m_Value(Y)))) + return BinaryOperator::CreateWithCopiedFlags(Opcode, X, Y, &I); + + // fabs(X) * fabs(X) -> X * X + // fabs(X) / fabs(X) -> X / X + if (Op0 == Op1 && match(Op0, m_FAbs(m_Value(X)))) + return BinaryOperator::CreateWithCopiedFlags(Opcode, X, X, &I); + + // fabs(X) * fabs(Y) --> fabs(X * Y) + // fabs(X) / fabs(Y) --> fabs(X / Y) + if (match(Op0, m_FAbs(m_Value(X))) && match(Op1, m_FAbs(m_Value(Y))) && + (Op0->hasOneUse() || Op1->hasOneUse())) { + IRBuilder<>::FastMathFlagGuard FMFGuard(Builder); + Builder.setFastMathFlags(I.getFastMathFlags()); + Value *XY = Builder.CreateBinOp(Opcode, X, Y); + Value *Fabs = Builder.CreateUnaryIntrinsic(Intrinsic::fabs, XY); + Fabs->takeName(&I); + return replaceInstUsesWith(I, Fabs); + } + + return nullptr; +} + +Instruction *InstCombinerImpl::visitFMul(BinaryOperator &I) { + if (Value *V = simplifyFMulInst(I.getOperand(0), I.getOperand(1), + I.getFastMathFlags(), + SQ.getWithInstruction(&I))) + return replaceInstUsesWith(I, V); + + if (SimplifyAssociativeOrCommutative(I)) + return &I; + + if (Instruction *X = foldVectorBinop(I)) + return X; + + if (Instruction *Phi = foldBinopWithPhiOperands(I)) + return Phi; + + if (Instruction *FoldedMul = foldBinOpIntoSelectOrPhi(I)) + return FoldedMul; + + if (Value *FoldedMul = foldMulSelectToNegate(I, Builder)) + return replaceInstUsesWith(I, FoldedMul); + + if (Instruction *R = foldFPSignBitOps(I)) + return R; + + // X * -1.0 --> -X + Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); + if (match(Op1, m_SpecificFP(-1.0))) + return 
/// Visit an 'fmul' instruction and try a fixed, ordered list of folds. Folds
/// that change rounding or special-value behavior are gated on the
/// fast-math flags carried by \p I (reassoc, nnan, nsz, fast), as checked
/// explicitly below.
///
/// \returns a replacement instruction, \p I itself when modified in place,
/// or nullptr if nothing changed.
Instruction *InstCombinerImpl::visitFMul(BinaryOperator &I) {
  if (Value *V = simplifyFMulInst(I.getOperand(0), I.getOperand(1),
                                  I.getFastMathFlags(),
                                  SQ.getWithInstruction(&I)))
    return replaceInstUsesWith(I, V);

  if (SimplifyAssociativeOrCommutative(I))
    return &I;

  if (Instruction *X = foldVectorBinop(I))
    return X;

  if (Instruction *Phi = foldBinopWithPhiOperands(I))
    return Phi;

  if (Instruction *FoldedMul = foldBinOpIntoSelectOrPhi(I))
    return FoldedMul;

  if (Value *FoldedMul = foldMulSelectToNegate(I, Builder))
    return replaceInstUsesWith(I, FoldedMul);

  if (Instruction *R = foldFPSignBitOps(I))
    return R;

  // X * -1.0 --> -X
  Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
  if (match(Op1, m_SpecificFP(-1.0)))
    return UnaryOperator::CreateFNegFMF(Op0, &I);

  // -X * C --> X * -C
  Value *X, *Y;
  Constant *C;
  if (match(Op0, m_FNeg(m_Value(X))) && match(Op1, m_Constant(C)))
    return BinaryOperator::CreateFMulFMF(X, ConstantExpr::getFNeg(C), &I);

  // (select A, B, C) * (select A, D, E) --> select A, (B*D), (C*E)
  if (Value *V = SimplifySelectsFeedingBinaryOp(I, Op0, Op1))
    return replaceInstUsesWith(I, V);

  // Everything in this block reorders floating-point operations and is
  // therefore only attempted when I carries the 'reassoc' flag.
  if (I.hasAllowReassoc()) {
    // Reassociate constant RHS with another constant to form constant
    // expression.
    if (match(Op1, m_Constant(C)) && C->isFiniteNonZeroFP()) {
      Constant *C1;
      if (match(Op0, m_OneUse(m_FDiv(m_Constant(C1), m_Value(X))))) {
        // (C1 / X) * C --> (C * C1) / X
        Constant *CC1 = ConstantExpr::getFMul(C, C1);
        if (CC1->isNormalFP())
          return BinaryOperator::CreateFDivFMF(CC1, X, &I);
      }
      if (match(Op0, m_FDiv(m_Value(X), m_Constant(C1)))) {
        // (X / C1) * C --> X * (C / C1)
        Constant *CDivC1 = ConstantExpr::getFDiv(C, C1);
        if (CDivC1->isNormalFP())
          return BinaryOperator::CreateFMulFMF(X, CDivC1, &I);

        // If the constant was a denormal, try reassociating differently.
        // (X / C1) * C --> X / (C1 / C)
        Constant *C1DivC = ConstantExpr::getFDiv(C1, C);
        if (Op0->hasOneUse() && C1DivC->isNormalFP())
          return BinaryOperator::CreateFDivFMF(X, C1DivC, &I);
      }

      // We do not need to match 'fadd C, X' and 'fsub X, C' because they are
      // canonicalized to 'fadd X, C'. Distributing the multiply may allow
      // further folds and (X * C) + C2 is 'fma'.
      if (match(Op0, m_OneUse(m_FAdd(m_Value(X), m_Constant(C1))))) {
        // (X + C1) * C --> (X * C) + (C * C1)
        Constant *CC1 = ConstantExpr::getFMul(C, C1);
        Value *XC = Builder.CreateFMulFMF(X, C, &I);
        return BinaryOperator::CreateFAddFMF(XC, CC1, &I);
      }
      if (match(Op0, m_OneUse(m_FSub(m_Constant(C1), m_Value(X))))) {
        // (C1 - X) * C --> (C * C1) - (X * C)
        Constant *CC1 = ConstantExpr::getFMul(C, C1);
        Value *XC = Builder.CreateFMulFMF(X, C, &I);
        return BinaryOperator::CreateFSubFMF(CC1, XC, &I);
      }
    }

    Value *Z;
    if (match(&I, m_c_FMul(m_OneUse(m_FDiv(m_Value(X), m_Value(Y))),
                           m_Value(Z)))) {
      // Sink division: (X / Y) * Z --> (X * Z) / Y
      Value *NewFMul = Builder.CreateFMulFMF(X, Z, &I);
      return BinaryOperator::CreateFDivFMF(NewFMul, Y, &I);
    }

    // sqrt(X) * sqrt(Y) -> sqrt(X * Y)
    // nnan disallows the possibility of returning a number if both operands are
    // negative (in that case, we should return NaN).
    if (I.hasNoNaNs() && match(Op0, m_OneUse(m_Sqrt(m_Value(X)))) &&
        match(Op1, m_OneUse(m_Sqrt(m_Value(Y))))) {
      Value *XY = Builder.CreateFMulFMF(X, Y, &I);
      Value *Sqrt = Builder.CreateUnaryIntrinsic(Intrinsic::sqrt, XY, &I);
      return replaceInstUsesWith(I, Sqrt);
    }

    // The following transforms are done irrespective of the number of uses
    // for the expression "1.0/sqrt(X)".
    //  1) 1.0/sqrt(X) * X -> X/sqrt(X)
    //  2) X * 1.0/sqrt(X) -> X/sqrt(X)
    // We always expect the backend to reduce X/sqrt(X) to sqrt(X), if it
    // has the necessary (reassoc) fast-math-flags.
    if (I.hasNoSignedZeros() &&
        match(Op0, (m_FDiv(m_SpecificFP(1.0), m_Value(Y)))) &&
        match(Y, m_Sqrt(m_Value(X))) && Op1 == X)
      return BinaryOperator::CreateFDivFMF(X, Y, &I);
    if (I.hasNoSignedZeros() &&
        match(Op1, (m_FDiv(m_SpecificFP(1.0), m_Value(Y)))) &&
        match(Y, m_Sqrt(m_Value(X))) && Op0 == X)
      return BinaryOperator::CreateFDivFMF(X, Y, &I);

    // Like the similar transform in instsimplify, this requires 'nsz' because
    // sqrt(-0.0) = -0.0, and -0.0 * -0.0 does not simplify to -0.0.
    // hasNUses(2): both uses of the squared operand are this multiply.
    if (I.hasNoNaNs() && I.hasNoSignedZeros() && Op0 == Op1 &&
        Op0->hasNUses(2)) {
      // Peek through fdiv to find squaring of square root:
      // (X / sqrt(Y)) * (X / sqrt(Y)) --> (X * X) / Y
      if (match(Op0, m_FDiv(m_Value(X), m_Sqrt(m_Value(Y))))) {
        Value *XX = Builder.CreateFMulFMF(X, X, &I);
        return BinaryOperator::CreateFDivFMF(XX, Y, &I);
      }
      // (sqrt(Y) / X) * (sqrt(Y) / X) --> Y / (X * X)
      if (match(Op0, m_FDiv(m_Sqrt(m_Value(Y)), m_Value(X)))) {
        Value *XX = Builder.CreateFMulFMF(X, X, &I);
        return BinaryOperator::CreateFDivFMF(Y, XX, &I);
      }
    }

    if (I.isOnlyUserOfAnyOperand()) {
      // pow(x, y) * pow(x, z) -> pow(x, y + z)
      if (match(Op0, m_Intrinsic<Intrinsic::pow>(m_Value(X), m_Value(Y))) &&
          match(Op1, m_Intrinsic<Intrinsic::pow>(m_Specific(X), m_Value(Z)))) {
        auto *YZ = Builder.CreateFAddFMF(Y, Z, &I);
        auto *NewPow = Builder.CreateBinaryIntrinsic(Intrinsic::pow, X, YZ, &I);
        return replaceInstUsesWith(I, NewPow);
      }

      // powi(x, y) * powi(x, z) -> powi(x, y + z)
      // The exponents are added with an integer add, so their types must
      // agree.
      if (match(Op0, m_Intrinsic<Intrinsic::powi>(m_Value(X), m_Value(Y))) &&
          match(Op1, m_Intrinsic<Intrinsic::powi>(m_Specific(X), m_Value(Z))) &&
          Y->getType() == Z->getType()) {
        auto *YZ = Builder.CreateAdd(Y, Z);
        auto *NewPow = Builder.CreateIntrinsic(
            Intrinsic::powi, {X->getType(), YZ->getType()}, {X, YZ}, &I);
        return replaceInstUsesWith(I, NewPow);
      }

      // exp(X) * exp(Y) -> exp(X + Y)
      if (match(Op0, m_Intrinsic<Intrinsic::exp>(m_Value(X))) &&
          match(Op1, m_Intrinsic<Intrinsic::exp>(m_Value(Y)))) {
        Value *XY = Builder.CreateFAddFMF(X, Y, &I);
        Value *Exp = Builder.CreateUnaryIntrinsic(Intrinsic::exp, XY, &I);
        return replaceInstUsesWith(I, Exp);
      }

      // exp2(X) * exp2(Y) -> exp2(X + Y)
      if (match(Op0, m_Intrinsic<Intrinsic::exp2>(m_Value(X))) &&
          match(Op1, m_Intrinsic<Intrinsic::exp2>(m_Value(Y)))) {
        Value *XY = Builder.CreateFAddFMF(X, Y, &I);
        Value *Exp2 = Builder.CreateUnaryIntrinsic(Intrinsic::exp2, XY, &I);
        return replaceInstUsesWith(I, Exp2);
      }
    }

    // (X*Y) * X => (X*X) * Y where Y != X
    //  The purpose is two-fold:
    //   1) to form a power expression (of X).
    //   2) potentially shorten the critical path: After transformation, the
    //  latency of the instruction Y is amortized by the expression of X*X,
    //  and therefore Y is in a "less critical" position compared to what it
    //  was before the transformation.
    if (match(Op0, m_OneUse(m_c_FMul(m_Specific(Op1), m_Value(Y)))) &&
        Op1 != Y) {
      Value *XX = Builder.CreateFMulFMF(Op1, Op1, &I);
      return BinaryOperator::CreateFMulFMF(XX, Y, &I);
    }
    if (match(Op1, m_OneUse(m_c_FMul(m_Specific(Op0), m_Value(Y)))) &&
        Op0 != Y) {
      Value *XX = Builder.CreateFMulFMF(Op0, Op0, &I);
      return BinaryOperator::CreateFMulFMF(XX, Y, &I);
    }
  }

  // log2(X * 0.5) * Y = log2(X) * Y - Y
  if (I.isFast()) {
    // The outer Log2 only records whether (and on which side) the pattern
    // matched. If both operands match, the second match overwrites X and Y.
    IntrinsicInst *Log2 = nullptr;
    if (match(Op0, m_OneUse(m_Intrinsic<Intrinsic::log2>(
            m_OneUse(m_FMul(m_Value(X), m_SpecificFP(0.5))))))) {
      Log2 = cast<IntrinsicInst>(Op0);
      Y = Op1;
    }
    if (match(Op1, m_OneUse(m_Intrinsic<Intrinsic::log2>(
            m_OneUse(m_FMul(m_Value(X), m_SpecificFP(0.5))))))) {
      Log2 = cast<IntrinsicInst>(Op1);
      Y = Op0;
    }
    if (Log2) {
      // NOTE: this inner 'Log2' deliberately-or-not shadows the outer
      // IntrinsicInst pointer; from here on the name refers to the new call.
      Value *Log2 = Builder.CreateUnaryIntrinsic(Intrinsic::log2, X, &I);
      Value *LogXTimesY = Builder.CreateFMulFMF(Log2, Y, &I);
      return BinaryOperator::CreateFSubFMF(LogXTimesY, Y, &I);
    }
  }

  return nullptr;
}
/// Fold a divide or remainder with a select instruction divisor when one of the
/// select operands is zero. In that case, we can use the other select operand
/// because div/rem by zero is undefined.
///
/// \param I the div/rem instruction whose operand 1 is inspected.
/// \returns true if operand 1 of \p I was replaced (and, possibly, earlier
/// uses of the select or its condition in the same block were rewritten);
/// false if the divisor is not a select with a zero arm.
bool InstCombinerImpl::simplifyDivRemOfSelectWithZeroOp(BinaryOperator &I) {
  SelectInst *SI = dyn_cast<SelectInst>(I.getOperand(1));
  if (!SI)
    return false;

  // Operand index (on the select) of the arm that is not the zero constant:
  // 1 is the true value, 2 is the false value.
  int NonNullOperand;
  if (match(SI->getTrueValue(), m_Zero()))
    // div/rem X, (Cond ? 0 : Y) -> div/rem X, Y
    NonNullOperand = 2;
  else if (match(SI->getFalseValue(), m_Zero()))
    // div/rem X, (Cond ? Y : 0) -> div/rem X, Y
    NonNullOperand = 1;
  else
    return false;

  // Change the div/rem to use 'Y' instead of the select.
  replaceOperand(I, 1, SI->getOperand(NonNullOperand));

  // Okay, we know we can replace the operand of the div/rem with 'Y' with no
  // problem. However, the select, or the condition of the select, may have
  // multiple uses. Based on our knowledge that the operand must be non-zero,
  // propagate the known value for the select into other uses of it, and
  // propagate a known value of the condition into its other users.

  // If the select is now unused and the condition has a single use, there is
  // nothing left to propagate: exit early.
  Value *SelectCond = SI->getCondition();
  if (SI->use_empty() && SelectCond->hasOneUse())
    return true;

  // Scan the current block backward, looking for other uses of SI.
  BasicBlock::iterator BBI = I.getIterator(), BBFront = I.getParent()->begin();
  Type *CondTy = SelectCond->getType();
  while (BBI != BBFront) {
    --BBI;
    // If we find an instruction that we can't assume will return, then
    // information from below it cannot be propagated above it.
    if (!isGuaranteedToTransferExecutionToSuccessor(&*BBI))
      break;

    // Replace uses of the select or its condition with the known values.
    for (Use &Op : BBI->operands()) {
      if (Op == SI) {
        replaceUse(Op, SI->getOperand(NonNullOperand));
        Worklist.push(&*BBI);
      } else if (Op == SelectCond) {
        // The condition is known true when the true arm is the non-zero one.
        replaceUse(Op, NonNullOperand == 1 ? ConstantInt::getTrue(CondTy)
                                           : ConstantInt::getFalse(CondTy));
        Worklist.push(&*BBI);
      }
    }

    // Once we have passed the definition of the select (or of its
    // condition), stop looking for uses of it: clear the pointer.
    if (&*BBI == SI)
      SI = nullptr;
    if (&*BBI == SelectCond)
      SelectCond = nullptr;

    // If we ran out of things to eliminate, break out of the loop.
    if (!SelectCond && !SI)
      break;

  }
  return true;
}
ConstantInt::getTrue(CondTy) + : ConstantInt::getFalse(CondTy)); + Worklist.push(&*BBI); + } + } + + // If we past the instruction, quit looking for it. + if (&*BBI == SI) + SI = nullptr; + if (&*BBI == SelectCond) + SelectCond = nullptr; + + // If we ran out of things to eliminate, break out of the loop. + if (!SelectCond && !SI) + break; + + } + return true; +} + +/// True if the multiply can not be expressed in an int this size. +static bool multiplyOverflows(const APInt &C1, const APInt &C2, APInt &Product, + bool IsSigned) { + bool Overflow; + Product = IsSigned ? C1.smul_ov(C2, Overflow) : C1.umul_ov(C2, Overflow); + return Overflow; +} + +/// True if C1 is a multiple of C2. Quotient contains C1/C2. +static bool isMultiple(const APInt &C1, const APInt &C2, APInt &Quotient, + bool IsSigned) { + assert(C1.getBitWidth() == C2.getBitWidth() && "Constant widths not equal"); + + // Bail if we will divide by zero. + if (C2.isZero()) + return false; + + // Bail if we would divide INT_MIN by -1. + if (IsSigned && C1.isMinSignedValue() && C2.isAllOnes()) + return false; + + APInt Remainder(C1.getBitWidth(), /*val=*/0ULL, IsSigned); + if (IsSigned) + APInt::sdivrem(C1, C2, Quotient, Remainder); + else + APInt::udivrem(C1, C2, Quotient, Remainder); + + return Remainder.isMinValue(); +} + +/// This function implements the transforms common to both integer division +/// instructions (udiv and sdiv). It is called by the visitors to those integer +/// division instructions. +/// Common integer divide transforms +Instruction *InstCombinerImpl::commonIDivTransforms(BinaryOperator &I) { + if (Instruction *Phi = foldBinopWithPhiOperands(I)) + return Phi; + + Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); + bool IsSigned = I.getOpcode() == Instruction::SDiv; + Type *Ty = I.getType(); + + // The RHS is known non-zero. 
/// This function implements the transforms common to both integer division
/// instructions (udiv and sdiv). It is called by the visitors to those integer
/// division instructions.
/// Common integer divide transforms
///
/// \param I the udiv/sdiv being visited; signedness is derived from its
/// opcode.
/// \returns a replacement instruction, \p I itself when modified in place,
/// or nullptr if no common fold applied.
Instruction *InstCombinerImpl::commonIDivTransforms(BinaryOperator &I) {
  if (Instruction *Phi = foldBinopWithPhiOperands(I))
    return Phi;

  Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
  bool IsSigned = I.getOpcode() == Instruction::SDiv;
  Type *Ty = I.getType();

  // The RHS is known non-zero.
  if (Value *V = simplifyValueKnownNonZero(I.getOperand(1), *this, I))
    return replaceOperand(I, 1, V);

  // Handle cases involving: [su]div X, (select Cond, Y, Z)
  // This does not apply for fdiv.
  if (simplifyDivRemOfSelectWithZeroOp(I))
    return &I;

  // If the divisor is a select-of-constants, try to constant fold all div ops:
  // C / (select Cond, TrueC, FalseC) --> select Cond, (C / TrueC), (C / FalseC)
  // TODO: Adapt simplifyDivRemOfSelectWithZeroOp to allow this and other folds.
  if (match(Op0, m_ImmConstant()) &&
      match(Op1, m_Select(m_Value(), m_ImmConstant(), m_ImmConstant()))) {
    if (Instruction *R = FoldOpIntoSelect(I, cast<SelectInst>(Op1),
                                          /*FoldWithMultiUse*/ true))
      return R;
  }

  // Folds that need a constant (scalar or splat) divisor C2.
  const APInt *C2;
  if (match(Op1, m_APInt(C2))) {
    Value *X;
    const APInt *C1;

    // (X / C1) / C2  -> X / (C1*C2)
    // Only when the inner division has the same signedness and C1*C2 does
    // not overflow.
    if ((IsSigned && match(Op0, m_SDiv(m_Value(X), m_APInt(C1)))) ||
        (!IsSigned && match(Op0, m_UDiv(m_Value(X), m_APInt(C1))))) {
      APInt Product(C1->getBitWidth(), /*val=*/0ULL, IsSigned);
      if (!multiplyOverflows(*C1, *C2, Product, IsSigned))
        return BinaryOperator::Create(I.getOpcode(), X,
                                      ConstantInt::get(Ty, Product));
    }

    // The multiply must carry the no-wrap flag matching the division's
    // signedness (nsw for sdiv, nuw for udiv).
    if ((IsSigned && match(Op0, m_NSWMul(m_Value(X), m_APInt(C1)))) ||
        (!IsSigned && match(Op0, m_NUWMul(m_Value(X), m_APInt(C1))))) {
      APInt Quotient(C1->getBitWidth(), /*val=*/0ULL, IsSigned);

      // (X * C1) / C2 -> X / (C2 / C1) if C2 is a multiple of C1.
      if (isMultiple(*C2, *C1, Quotient, IsSigned)) {
        auto *NewDiv = BinaryOperator::Create(I.getOpcode(), X,
                                              ConstantInt::get(Ty, Quotient));
        NewDiv->setIsExact(I.isExact());
        return NewDiv;
      }

      // (X * C1) / C2 -> X * (C1 / C2) if C1 is a multiple of C2.
      if (isMultiple(*C1, *C2, Quotient, IsSigned)) {
        auto *Mul = BinaryOperator::Create(Instruction::Mul, X,
                                           ConstantInt::get(Ty, Quotient));
        // nuw is kept only for the unsigned case; nsw is copied from the
        // original multiply as-is.
        auto *OBO = cast<OverflowingBinaryOperator>(Op0);
        Mul->setHasNoUnsignedWrap(!IsSigned && OBO->hasNoUnsignedWrap());
        Mul->setHasNoSignedWrap(OBO->hasNoSignedWrap());
        return Mul;
      }
    }

    // Same two folds with the multiply written as a left shift. The bound
    // on C1 keeps the bit index passed to getOneBitSet below within the
    // type; the signed case additionally excludes the top bit.
    if ((IsSigned && match(Op0, m_NSWShl(m_Value(X), m_APInt(C1))) &&
         C1->ult(C1->getBitWidth() - 1)) ||
        (!IsSigned && match(Op0, m_NUWShl(m_Value(X), m_APInt(C1))) &&
         C1->ult(C1->getBitWidth()))) {
      APInt Quotient(C1->getBitWidth(), /*val=*/0ULL, IsSigned);
      APInt C1Shifted = APInt::getOneBitSet(
          C1->getBitWidth(), static_cast<unsigned>(C1->getZExtValue()));

      // (X << C1) / C2 -> X / (C2 >> C1) if C2 is a multiple of 1 << C1.
      if (isMultiple(*C2, C1Shifted, Quotient, IsSigned)) {
        auto *BO = BinaryOperator::Create(I.getOpcode(), X,
                                          ConstantInt::get(Ty, Quotient));
        BO->setIsExact(I.isExact());
        return BO;
      }

      // (X << C1) / C2 -> X * ((1 << C1) / C2) if 1 << C1 is a multiple of C2.
      if (isMultiple(C1Shifted, *C2, Quotient, IsSigned)) {
        auto *Mul = BinaryOperator::Create(Instruction::Mul, X,
                                           ConstantInt::get(Ty, Quotient));
        auto *OBO = cast<OverflowingBinaryOperator>(Op0);
        Mul->setHasNoUnsignedWrap(!IsSigned && OBO->hasNoUnsignedWrap());
        Mul->setHasNoSignedWrap(OBO->hasNoSignedWrap());
        return Mul;
      }
    }

    if (!C2->isZero()) // avoid X udiv 0
      if (Instruction *FoldedDiv = foldBinOpIntoSelectOrPhi(I))
        return FoldedDiv;
  }

  if (match(Op0, m_One())) {
    assert(!Ty->isIntOrIntVectorTy(1) && "i1 divide not removed?");
    if (IsSigned) {
      // 1 / 0 --> undef ; 1 / 1 --> 1 ; 1 / -1 --> -1 ; 1 / anything else --> 0
      // (Op1 + 1) u< 3 ? Op1 : 0
      // Op1 must be frozen because we are increasing its number of uses.
      Value *F1 = Builder.CreateFreeze(Op1, Op1->getName() + ".fr");
      // Op0 is the constant one here, so this computes F1 + 1.
      Value *Inc = Builder.CreateAdd(F1, Op0);
      Value *Cmp = Builder.CreateICmpULT(Inc, ConstantInt::get(Ty, 3));
      return SelectInst::Create(Cmp, F1, ConstantInt::get(Ty, 0));
    } else {
      // If Op1 is 0 then it's undefined behaviour. If Op1 is 1 then the
      // result is one, otherwise it's zero.
      return new ZExtInst(Builder.CreateICmpEQ(Op1, Op0), Ty);
    }
  }

  // See if we can fold away this div instruction.
  if (SimplifyDemandedInstructionBits(I))
    return &I;

  // (X - (X rem Y)) / Y -> X / Y; usually originates as ((X / Y) * Y) / Y
  Value *X, *Z;
  if (match(Op0, m_Sub(m_Value(X), m_Value(Z)))) // (X - Z) / Y; Y = Op1
    if ((IsSigned && match(Z, m_SRem(m_Specific(X), m_Specific(Op1)))) ||
        (!IsSigned && match(Z, m_URem(m_Specific(X), m_Specific(Op1)))))
      return BinaryOperator::Create(I.getOpcode(), X, Op1);

  // (X << Y) / X -> 1 << Y
  // The shift must not wrap in the division's signedness, and the new shift
  // is created with that same flag.
  Value *Y;
  if (IsSigned && match(Op0, m_NSWShl(m_Specific(Op1), m_Value(Y))))
    return BinaryOperator::CreateNSWShl(ConstantInt::get(Ty, 1), Y);
  if (!IsSigned && match(Op0, m_NUWShl(m_Specific(Op1), m_Value(Y))))
    return BinaryOperator::CreateNUWShl(ConstantInt::get(Ty, 1), Y);

  // X / (X * Y) -> 1 / Y if the multiplication does not overflow.
  if (match(Op1, m_c_Mul(m_Specific(Op0), m_Value(Y)))) {
    bool HasNSW = cast<OverflowingBinaryOperator>(Op1)->hasNoSignedWrap();
    bool HasNUW = cast<OverflowingBinaryOperator>(Op1)->hasNoUnsignedWrap();
    if ((IsSigned && HasNSW) || (!IsSigned && HasNUW)) {
      // Rewrite I in place rather than building a new division.
      replaceOperand(I, 0, ConstantInt::get(Ty, 1));
      replaceOperand(I, 1, Y);
      return &I;
    }
  }

  return nullptr;
}
static Value *takeLog2(IRBuilderBase &Builder, Value *Op, unsigned Depth,
                       bool DoFold) {
  // Two-phase helper: in query mode (DoFold == false) nothing is built and a
  // non-null sentinel is returned instead. The sentinel is only ever tested
  // against nullptr by the callers below; it must never be dereferenced.
  auto IfFold = [DoFold](function_ref<Value *()> Fn) {
    if (!DoFold)
      return reinterpret_cast<Value *>(-1);
    return Fn();
  };

  // FIXME: assert that Op1 isn't/doesn't contain undef.

  // log2(2^C) -> C
  if (match(Op, m_Power2()))
    return IfFold([&]() {
      Constant *C = ConstantExpr::getExactLogBase2(cast<Constant>(Op));
      if (!C)
        llvm_unreachable("Failed to constant fold udiv -> logbase2");
      return C;
    });

  // The remaining tests are all recursive, so bail out if we hit the limit.
  // Note the post-increment: every recursive call below receives Depth + 1.
  if (Depth++ == MaxDepth)
    return nullptr;

  // log2(zext X) -> zext log2(X)
  // FIXME: Require one use?
  Value *X, *Y;
  if (match(Op, m_ZExt(m_Value(X))))
    if (Value *LogX = takeLog2(Builder, X, Depth, DoFold))
      return IfFold([&]() { return Builder.CreateZExt(LogX, Op->getType()); });

  // log2(X << Y) -> log2(X) + Y
  // FIXME: Require one use unless X is 1?
  if (match(Op, m_Shl(m_Value(X), m_Value(Y))))
    if (Value *LogX = takeLog2(Builder, X, Depth, DoFold))
      return IfFold([&]() { return Builder.CreateAdd(LogX, Y); });

  // log2(Cond ? X : Y) -> Cond ? log2(X) : log2(Y)
  // Both arms must be foldable; if either recursion fails we fall through.
  // FIXME: missed optimization: if one of the hands of select is/contains
  //        undef, just directly pick the other one.
  // FIXME: can both hands contain undef?
  // FIXME: Require one use?
  if (SelectInst *SI = dyn_cast<SelectInst>(Op))
    if (Value *LogX = takeLog2(Builder, SI->getOperand(1), Depth, DoFold))
      if (Value *LogY = takeLog2(Builder, SI->getOperand(2), Depth, DoFold))
        return IfFold([&]() {
          return Builder.CreateSelect(SI->getOperand(0), LogX, LogY);
        });

  // log2(umin(X, Y)) -> umin(log2(X), log2(Y))
  // log2(umax(X, Y)) -> umax(log2(X), log2(Y))
  // Only unsigned min/max with a single use are handled here.
  auto *MinMax = dyn_cast<MinMaxIntrinsic>(Op);
  if (MinMax && MinMax->hasOneUse() && !MinMax->isSigned())
    if (Value *LogX = takeLog2(Builder, MinMax->getLHS(), Depth, DoFold))
      if (Value *LogY = takeLog2(Builder, MinMax->getRHS(), Depth, DoFold))
        return IfFold([&]() {
          return Builder.CreateBinaryIntrinsic(
              MinMax->getIntrinsicID(), LogX, LogY);
        });

  // No pattern matched (or the recursion limit was reached).
  return nullptr;
}

/// If we have zero-extended operands of an unsigned div or rem, we may be able
/// to narrow the operation (sink the zext below the math).
///
/// \param I        the udiv/urem being visited; its opcode is reused for the
///                 narrow operation.
/// \param Builder  used to create the narrow binary operation.
/// \returns a new (not yet inserted) zext of the narrow operation, or nullptr
///          if neither pattern applies.
static Instruction *narrowUDivURem(BinaryOperator &I,
                                   InstCombiner::BuilderTy &Builder) {
  Instruction::BinaryOps Opcode = I.getOpcode();
  Value *N = I.getOperand(0);
  Value *D = I.getOperand(1);
  Type *Ty = I.getType();
  Value *X, *Y;
  // Both operands are zexts from the same source type and at least one of the
  // zexts has no other user.
  if (match(N, m_ZExt(m_Value(X))) && match(D, m_ZExt(m_Value(Y))) &&
      X->getType() == Y->getType() && (N->hasOneUse() || D->hasOneUse())) {
    // udiv (zext X), (zext Y) --> zext (udiv X, Y)
    // urem (zext X), (zext Y) --> zext (urem X, Y)
    Value *NarrowOp = Builder.CreateBinOp(Opcode, X, Y);
    return new ZExtInst(NarrowOp, Ty);
  }

  // One operand is a single-use zext and the other is a constant.
  Constant *C;
  if ((match(N, m_OneUse(m_ZExt(m_Value(X)))) && match(D, m_Constant(C))) ||
      (match(D, m_OneUse(m_ZExt(m_Value(X)))) && match(N, m_Constant(C)))) {
    // If the constant is the same in the smaller type, use the narrow version.
    // (Round-trip check: trunc followed by zext must reproduce C.)
    Constant *TruncC = ConstantExpr::getTrunc(C, X->getType());
    if (ConstantExpr::getZExt(TruncC, Ty) != C)
      return nullptr;

    // udiv (zext X), C --> zext (udiv X, C')
    // urem (zext X), C --> zext (urem X, C')
    // udiv C, (zext X) --> zext (udiv C', X)
    // urem C, (zext X) --> zext (urem C', X)
    // Keep the constant on the side it was originally on.
    Value *NarrowOp = isa<Constant>(D) ? Builder.CreateBinOp(Opcode, X, TruncC)
                                       : Builder.CreateBinOp(Opcode, TruncC, X);
    return new ZExtInst(NarrowOp, Ty);
  }

  return nullptr;
}

/// Visitor for udiv. Tries, in order: instruction simplification, the vector
/// binop fold, the transforms shared with sdiv, and then the udiv-specific
/// folds below. Returns nullptr when nothing applied.
Instruction *InstCombinerImpl::visitUDiv(BinaryOperator &I) {
  if (Value *V = simplifyUDivInst(I.getOperand(0), I.getOperand(1),
                                  SQ.getWithInstruction(&I)))
    return replaceInstUsesWith(I, V);

  if (Instruction *X = foldVectorBinop(I))
    return X;

  // Handle the integer div common cases
  if (Instruction *Common = commonIDivTransforms(I))
    return Common;

  Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
  Value *X;
  const APInt *C1, *C2;
  if (match(Op0, m_LShr(m_Value(X), m_APInt(C1))) && match(Op1, m_APInt(C2))) {
    // (X lshr C1) udiv C2 --> X udiv (C2 << C1)
    // Only valid when the shifted divisor still fits in the type.
    bool Overflow;
    APInt C2ShlC1 = C2->ushl_ov(*C1, Overflow);
    if (!Overflow) {
      // The result is exact only if both the udiv and the lshr were exact.
      bool IsExact = I.isExact() && match(Op0, m_Exact(m_Value()));
      BinaryOperator *BO = BinaryOperator::CreateUDiv(
          X, ConstantInt::get(X->getType(), C2ShlC1));
      if (IsExact)
        BO->setIsExact();
      return BO;
    }
  }

  // Op0 / C where C is large (negative) --> zext (Op0 >= C)
  // TODO: Could use isKnownNegative() to handle non-constant values.
  Type *Ty = I.getType();
  if (match(Op1, m_Negative())) {
    Value *Cmp = Builder.CreateICmpUGE(Op0, Op1);
    return CastInst::CreateZExtOrBitCast(Cmp, Ty);
  }
  // Op0 / (sext i1 X) --> zext (Op0 == -1) (if X is 0, the div is undefined)
  if (match(Op1, m_SExt(m_Value(X))) && X->getType()->isIntOrIntVectorTy(1)) {
    Value *Cmp = Builder.CreateICmpEQ(Op0, ConstantInt::getAllOnesValue(Ty));
    return CastInst::CreateZExtOrBitCast(Cmp, Ty);
  }

  if (Instruction *NarrowDiv = narrowUDivURem(I, Builder))
    return NarrowDiv;

  // If the udiv operands are non-overflowing multiplies with a common operand,
  // then eliminate the common factor:
  // (A * B) / (A * X) --> B / X (and commuted variants)
  // TODO: The code would be reduced if we had m_c_NUWMul pattern matching.
  // TODO: If -reassociation handled this generally, we could remove this.
  Value *A, *B;
  if (match(Op0, m_NUWMul(m_Value(A), m_Value(B)))) {
    if (match(Op1, m_NUWMul(m_Specific(A), m_Value(X))) ||
        match(Op1, m_NUWMul(m_Value(X), m_Specific(A))))
      return BinaryOperator::CreateUDiv(B, X);
    if (match(Op1, m_NUWMul(m_Specific(B), m_Value(X))) ||
        match(Op1, m_NUWMul(m_Value(X), m_Specific(B))))
      return BinaryOperator::CreateUDiv(A, X);
  }

  // Op0 udiv Op1 -> Op0 lshr log2(Op1), if log2() folds away.
  // First a dry run (DoFold=false) that creates nothing; only if it succeeds
  // is the log2 expression actually materialized.
  if (takeLog2(Builder, Op1, /*Depth*/0, /*DoFold*/false)) {
    Value *Res = takeLog2(Builder, Op1, /*Depth*/0, /*DoFold*/true);
    return replaceInstUsesWith(
        I, Builder.CreateLShr(Op0, Res, I.getName(), I.isExact()));
  }

  return nullptr;
}

/// Visitor for sdiv. Tries, in order: instruction simplification, the vector
/// binop fold, the transforms shared with udiv, and then the sdiv-specific
/// folds below. Returns nullptr when nothing applied.
Instruction *InstCombinerImpl::visitSDiv(BinaryOperator &I) {
  if (Value *V = simplifySDivInst(I.getOperand(0), I.getOperand(1),
                                  SQ.getWithInstruction(&I)))
    return replaceInstUsesWith(I, V);

  if (Instruction *X = foldVectorBinop(I))
    return X;

  // Handle the integer div common cases
  if (Instruction *Common = commonIDivTransforms(I))
    return Common;

  Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
  Type *Ty = I.getType();
  Value *X;
  // sdiv Op0, -1 --> -Op0
  // sdiv Op0, (sext i1 X) --> -Op0 (because if X is 0, the op is undefined)
  if (match(Op1, m_AllOnes()) ||
      (match(Op1, m_SExt(m_Value(X))) && X->getType()->isIntOrIntVectorTy(1)))
    return BinaryOperator::CreateNeg(Op0);

  // X / INT_MIN --> X == INT_MIN
  if (match(Op1, m_SignMask()))
    return new ZExtInst(Builder.CreateICmpEQ(Op0, Op1), Ty);

  // sdiv exact X, 1<<C  -->    ashr exact X, C   iff  1<<C  is non-negative
  // sdiv exact X, -1<<C -->  -(ashr exact X, C)
  if (I.isExact() && ((match(Op1, m_Power2()) && match(Op1, m_NonNegative())) ||
                      match(Op1, m_NegatedPower2()))) {
    bool DivisorWasNegative = match(Op1, m_NegatedPower2());
    // Work with the positive power of two from here on.
    if (DivisorWasNegative)
      Op1 = ConstantExpr::getNeg(cast<Constant>(Op1));
    auto *AShr = BinaryOperator::CreateExactAShr(
        Op0, ConstantExpr::getExactLogBase2(cast<Constant>(Op1)), I.getName());
    if (!DivisorWasNegative)
      return AShr;
    // Negative divisor: the shift becomes an intermediate value (inserted here
    // and renamed with a ".neg" suffix); the returned negation takes I's name.
    Builder.Insert(AShr);
    AShr->setName(I.getName() + ".neg");
    return BinaryOperator::CreateNeg(AShr, I.getName());
  }

  const APInt *Op1C;
  if (match(Op1, m_APInt(Op1C))) {
    // If the dividend is sign-extended and the constant divisor is small enough
    // to fit in the source type, shrink the division to the narrower type:
    // (sext X) sdiv C --> sext (X sdiv C)
    Value *Op0Src;
    if (match(Op0, m_OneUse(m_SExt(m_Value(Op0Src)))) &&
        Op0Src->getType()->getScalarSizeInBits() >= Op1C->getMinSignedBits()) {

      // In the general case, we need to make sure that the dividend is not the
      // minimum signed value because dividing that by -1 is UB. But here, we
      // know that the -1 divisor case is already handled above.

      Constant *NarrowDivisor =
          ConstantExpr::getTrunc(cast<Constant>(Op1), Op0Src->getType());
      Value *NarrowOp = Builder.CreateSDiv(Op0Src, NarrowDivisor);
      return new SExtInst(NarrowOp, Ty);
    }

    // -X / C --> X / -C (if the negation doesn't overflow).
    // TODO: This could be enhanced to handle arbitrary vector constants by
    //       checking if all elements are not the min-signed-val.
    if (!Op1C->isMinSignedValue() &&
        match(Op0, m_NSWSub(m_Zero(), m_Value(X)))) {
      Constant *NegC = ConstantInt::get(Ty, -(*Op1C));
      Instruction *BO = BinaryOperator::CreateSDiv(X, NegC);
      BO->setIsExact(I.isExact());
      return BO;
    }
  }

  // -X / Y --> -(X / Y)
  // Requires the negation to be 'nsw' and to have no other user.
  Value *Y;
  if (match(&I, m_SDiv(m_OneUse(m_NSWSub(m_Zero(), m_Value(X))), m_Value(Y))))
    return BinaryOperator::CreateNSWNeg(
        Builder.CreateSDiv(X, Y, I.getName(), I.isExact()));

  // abs(X) / X --> X > -1 ? 1 : -1
  // X / abs(X) --> X > -1 ? 1 : -1
  // Only matches abs whose second argument is 1 and which has a single use.
  if (match(&I, m_c_BinOp(
                    m_OneUse(m_Intrinsic<Intrinsic::abs>(m_Value(X), m_One())),
                    m_Deferred(X)))) {
    Value *Cond = Builder.CreateIsNotNeg(X);
    return SelectInst::Create(Cond, ConstantInt::get(Ty, 1),
                              ConstantInt::getAllOnesValue(Ty));
  }

  // If the sign bits of both operands are zero (i.e. we can prove they are
  // unsigned inputs), turn this into a udiv.
  APInt Mask(APInt::getSignMask(Ty->getScalarSizeInBits()));
  if (MaskedValueIsZero(Op0, Mask, 0, &I)) {
    if (MaskedValueIsZero(Op1, Mask, 0, &I)) {
      // X sdiv Y -> X udiv Y, iff X and Y don't have sign bit set
      auto *BO = BinaryOperator::CreateUDiv(Op0, Op1, I.getName());
      BO->setIsExact(I.isExact());
      return BO;
    }

    if (match(Op1, m_NegatedPower2())) {
      // X sdiv (-(1 << C)) -> -(X sdiv (1 << C)) ->
      //                    -> -(X udiv (1 << C)) -> -(X u>> C)
      Constant *CNegLog2 = ConstantExpr::getExactLogBase2(
          ConstantExpr::getNeg(cast<Constant>(Op1)));
      Value *Shr = Builder.CreateLShr(Op0, CNegLog2, I.getName(), I.isExact());
      return BinaryOperator::CreateNeg(Shr);
    }

    if (isKnownToBeAPowerOfTwo(Op1, /*OrZero*/ true, 0, &I)) {
      // X sdiv (1 << Y) -> X udiv (1 << Y) ( -> X u>> Y)
      // Safe because the only negative value (1 << Y) can take on is
      // INT_MIN, and X sdiv INT_MIN == X udiv INT_MIN == 0 if X doesn't have
      // the sign bit set.
      auto *BO = BinaryOperator::CreateUDiv(Op0, Op1, I.getName());
      BO->setIsExact(I.isExact());
      return BO;
    }
  }

  return nullptr;
}

/// Remove negation and try to convert division into multiplication.
///
/// Only applies when the divisor (operand 1) is a constant. Returns a new
/// fdiv/fmul carrying I's fast-math flags, or nullptr if no fold applies.
static Instruction *foldFDivConstantDivisor(BinaryOperator &I) {
  Constant *C;
  if (!match(I.getOperand(1), m_Constant(C)))
    return nullptr;

  // -X / C --> X / -C
  Value *X;
  if (match(I.getOperand(0), m_FNeg(m_Value(X))))
    return BinaryOperator::CreateFDivFMF(X, ConstantExpr::getFNeg(C), &I);

  // If the constant divisor has an exact inverse, this is always safe. If not,
  // then we can still create a reciprocal if fast-math-flags allow it and the
  // constant is a regular number (not zero, infinite, or denormal).
  if (!(C->hasExactInverseFP() || (I.hasAllowReciprocal() && C->isNormalFP())))
    return nullptr;

  // Disallow denormal constants because we don't know what would happen
  // on all targets.
  // TODO: Use Intrinsic::canonicalize or let function attributes tell us that
  // denorms are flushed?
  auto *RecipC = ConstantExpr::getFDiv(ConstantFP::get(I.getType(), 1.0), C);
  if (!RecipC->isNormalFP())
    return nullptr;

  // X / C --> X * (1 / C)
  return BinaryOperator::CreateFMulFMF(I.getOperand(0), RecipC, &I);
}

/// Remove negation and try to reassociate constant math.
///
/// Only applies when the dividend (operand 0) is a constant. Returns a new
/// fdiv carrying I's fast-math flags, or nullptr if no fold applies.
static Instruction *foldFDivConstantDividend(BinaryOperator &I) {
  Constant *C;
  if (!match(I.getOperand(0), m_Constant(C)))
    return nullptr;

  // C / -X --> -C / X
  Value *X;
  if (match(I.getOperand(1), m_FNeg(m_Value(X))))
    return BinaryOperator::CreateFDivFMF(ConstantExpr::getFNeg(C), X, &I);

  // Everything below needs both 'reassoc' and 'arcp'.
  if (!I.hasAllowReassoc() || !I.hasAllowReciprocal())
    return nullptr;

  // Try to reassociate C / X expressions where X includes another constant.
  Constant *C2, *NewC = nullptr;
  if (match(I.getOperand(1), m_FMul(m_Value(X), m_Constant(C2)))) {
    // C / (X * C2) --> (C / C2) / X
    NewC = ConstantExpr::getFDiv(C, C2);
  } else if (match(I.getOperand(1), m_FDiv(m_Value(X), m_Constant(C2)))) {
    // C / (X / C2) --> (C * C2) / X
    NewC = ConstantExpr::getFMul(C, C2);
  }
  // Disallow denormal constants because we don't know what would happen
  // on all targets.
  // TODO: Use Intrinsic::canonicalize or let function attributes tell us that
  // denorms are flushed?
  // NewC is still null here when neither pattern above matched.
  if (!NewC || !NewC->isNormalFP())
    return nullptr;

  return BinaryOperator::CreateFDivFMF(NewC, X, &I);
}

/// Negate the exponent of pow/exp to fold division-by-pow() into multiply.
static Instruction *foldFDivPowDivisor(BinaryOperator &I,
                                       InstCombiner::BuilderTy &Builder) {
  Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
  // The divisor must be a single-use intrinsic call and the fdiv must carry
  // both 'reassoc' and 'arcp'.
  auto *II = dyn_cast<IntrinsicInst>(Op1);
  if (!II || !II->hasOneUse() || !I.hasAllowReassoc() ||
      !I.hasAllowReciprocal())
    return nullptr;

  // Z / pow(X, Y) --> Z * pow(X, -Y)
  // Z / exp{2}(Y) --> Z * exp{2}(-Y)
  // In the general case, this creates an extra instruction, but fmul allows
  // for better canonicalization and optimization than fdiv.
  Intrinsic::ID IID = II->getIntrinsicID();
  SmallVector<Value *> Args;
  switch (IID) {
  case Intrinsic::pow:
    Args.push_back(II->getArgOperand(0));
    Args.push_back(Builder.CreateFNegFMF(II->getArgOperand(1), &I));
    break;
  case Intrinsic::powi: {
    // Require 'ninf' assuming that makes powi(X, -INT_MIN) acceptable.
    // That is, X ** (huge negative number) is 0.0, ~1.0, or INF and so
    // dividing by that is INF, ~1.0, or 0.0. Code that uses powi allows
    // non-standard results, so this corner case should be acceptable if the
    // code rules out INF values.
    if (!I.hasNoInfs())
      return nullptr;
    Args.push_back(II->getArgOperand(0));
    // The exponent is an integer here, so use integer negation.
    Args.push_back(Builder.CreateNeg(II->getArgOperand(1)));
    // powi is built with two overload types (result and exponent), hence the
    // separate CreateIntrinsic call and early return.
    Type *Tys[] = {I.getType(), II->getArgOperand(1)->getType()};
    Value *Pow = Builder.CreateIntrinsic(IID, Tys, Args, &I);
    return BinaryOperator::CreateFMulFMF(Op0, Pow, &I);
  }
  case Intrinsic::exp:
  case Intrinsic::exp2:
    Args.push_back(Builder.CreateFNegFMF(II->getArgOperand(0), &I));
    break;
  default:
    return nullptr;
  }
  Value *Pow = Builder.CreateIntrinsic(IID, I.getType(), Args, &I);
  return BinaryOperator::CreateFMulFMF(Op0, Pow, &I);
}

/// Visitor for fdiv. Tries instruction simplification, the vector/phi folds,
/// the constant-operand helpers, and then the fast-math dependent folds below.
/// Returns nullptr when nothing applied.
Instruction *InstCombinerImpl::visitFDiv(BinaryOperator &I) {
  Module *M = I.getModule();

  if (Value *V = simplifyFDivInst(I.getOperand(0), I.getOperand(1),
                                  I.getFastMathFlags(),
                                  SQ.getWithInstruction(&I)))
    return replaceInstUsesWith(I, V);

  if (Instruction *X = foldVectorBinop(I))
    return X;

  if (Instruction *Phi = foldBinopWithPhiOperands(I))
    return Phi;

  if (Instruction *R = foldFDivConstantDivisor(I))
    return R;

  if (Instruction *R = foldFDivConstantDividend(I))
    return R;

  if (Instruction *R = foldFPSignBitOps(I))
    return R;

  // Constant on one side and a select on the other: try to push the fdiv into
  // the select arms.
  Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
  if (isa<Constant>(Op0))
    if (SelectInst *SI = dyn_cast<SelectInst>(Op1))
      if (Instruction *R = FoldOpIntoSelect(I, SI))
        return R;

  if (isa<Constant>(Op1))
    if (SelectInst *SI = dyn_cast<SelectInst>(Op0))
      if (Instruction *R = FoldOpIntoSelect(I, SI))
        return R;

  if (I.hasAllowReassoc() && I.hasAllowReciprocal()) {
    Value *X, *Y;
    // The two-constant cases are excluded by the isa<Constant> checks.
    if (match(Op0, m_OneUse(m_FDiv(m_Value(X), m_Value(Y)))) &&
        (!isa<Constant>(Y) || !isa<Constant>(Op1))) {
      // (X / Y) / Z => X / (Y * Z)
      Value *YZ = Builder.CreateFMulFMF(Y, Op1, &I);
      return BinaryOperator::CreateFDivFMF(X, YZ, &I);
    }
    if (match(Op1, m_OneUse(m_FDiv(m_Value(X), m_Value(Y)))) &&
        (!isa<Constant>(Y) || !isa<Constant>(Op0))) {
      // Z / (X / Y) => (Y * Z) / X
      Value *YZ = Builder.CreateFMulFMF(Y, Op0, &I);
      return BinaryOperator::CreateFDivFMF(YZ, X, &I);
    }
    // Z / (1.0 / Y) => (Y * Z)
    //
    // This is a special case of Z / (X / Y) => (Y * Z) / X, with X = 1.0. The
    // m_OneUse check is avoided because even in the case of the multiple uses
    // for 1.0/Y, the number of instructions remain the same and a division is
    // replaced by a multiplication.
    if (match(Op1, m_FDiv(m_SpecificFP(1.0), m_Value(Y))))
      return BinaryOperator::CreateFMulFMF(Y, Op0, &I);
  }

  if (I.hasAllowReassoc() && Op0->hasOneUse() && Op1->hasOneUse()) {
    // sin(X) / cos(X) -> tan(X)
    // cos(X) / sin(X) -> 1/tan(X) (cotangent)
    Value *X;
    bool IsTan = match(Op0, m_Intrinsic<Intrinsic::sin>(m_Value(X))) &&
                 match(Op1, m_Intrinsic<Intrinsic::cos>(m_Specific(X)));
    bool IsCot =
        !IsTan && match(Op0, m_Intrinsic<Intrinsic::cos>(m_Value(X))) &&
                  match(Op1, m_Intrinsic<Intrinsic::sin>(m_Specific(X)));

    // Only fold when a tan library function is available for this type.
    if ((IsTan || IsCot) && hasFloatFn(M, &TLI, I.getType(), LibFunc_tan,
                                       LibFunc_tanf, LibFunc_tanl)) {
      // A local builder positioned at I, with I's fast-math flags applied for
      // the lifetime of the guard.
      IRBuilder<> B(&I);
      IRBuilder<>::FastMathFlagGuard FMFGuard(B);
      B.setFastMathFlags(I.getFastMathFlags());
      // Attributes are taken from the callee of the dividend call (Op0).
      AttributeList Attrs =
          cast<CallBase>(Op0)->getCalledFunction()->getAttributes();
      Value *Res = emitUnaryFloatFnCall(X, &TLI, LibFunc_tan, LibFunc_tanf,
                                        LibFunc_tanl, B, Attrs);
      if (IsCot)
        Res = B.CreateFDiv(ConstantFP::get(I.getType(), 1.0), Res);
      return replaceInstUsesWith(I, Res);
    }
  }

  // X / (X * Y) --> 1.0 / Y
  // Reassociate to (X / X -> 1.0) is legal when NaNs are not allowed.
  // We can ignore the possibility that X is infinity because INF/INF is NaN.
  // The fdiv is updated in place rather than replaced.
  Value *X, *Y;
  if (I.hasNoNaNs() && I.hasAllowReassoc() &&
      match(Op1, m_c_FMul(m_Specific(Op0), m_Value(Y)))) {
    replaceOperand(I, 0, ConstantFP::get(I.getType(), 1.0));
    replaceOperand(I, 1, Y);
    return &I;
  }

  // X / fabs(X) -> copysign(1.0, X)
  // fabs(X) / X -> copysign(1.0, X)
  if (I.hasNoNaNs() && I.hasNoInfs() &&
      (match(&I, m_FDiv(m_Value(X), m_FAbs(m_Deferred(X)))) ||
       match(&I, m_FDiv(m_FAbs(m_Value(X)), m_Deferred(X))))) {
    Value *V = Builder.CreateBinaryIntrinsic(
        Intrinsic::copysign, ConstantFP::get(I.getType(), 1.0), X, &I);
    return replaceInstUsesWith(I, V);
  }

  if (Instruction *Mul = foldFDivPowDivisor(I, Builder))
    return Mul;

  return nullptr;
}

/// This function implements the transforms common to both integer remainder
/// instructions (urem and srem). It is called by the visitors to those integer
/// remainder instructions.
/// Common integer remainder transforms
Instruction *InstCombinerImpl::commonIRemTransforms(BinaryOperator &I) {
  if (Instruction *Phi = foldBinopWithPhiOperands(I))
    return Phi;

  Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);

  // The RHS is known non-zero.
  if (Value *V = simplifyValueKnownNonZero(I.getOperand(1), *this, I))
    return replaceOperand(I, 1, V);

  // Handle cases involving: rem X, (select Cond, Y, Z)
  if (simplifyDivRemOfSelectWithZeroOp(I))
    return &I;

  // If the divisor is a select-of-constants, try to constant fold all rem ops:
  // C % (select Cond, TrueC, FalseC) --> select Cond, (C % TrueC), (C % FalseC)
  // TODO: Adapt simplifyDivRemOfSelectWithZeroOp to allow this and other folds.
  if (match(Op0, m_ImmConstant()) &&
      match(Op1, m_Select(m_Value(), m_ImmConstant(), m_ImmConstant()))) {
    if (Instruction *R = FoldOpIntoSelect(I, cast<SelectInst>(Op1),
                                          /*FoldWithMultiUse*/ true))
      return R;
  }

  // The remaining folds apply only to a constant divisor with an instruction
  // dividend.
  if (isa<Constant>(Op1)) {
    if (Instruction *Op0I = dyn_cast<Instruction>(Op0)) {
      if (SelectInst *SI = dyn_cast<SelectInst>(Op0I)) {
        if (Instruction *R = FoldOpIntoSelect(I, SI))
          return R;
      } else if (auto *PN = dyn_cast<PHINode>(Op0I)) {
        // Exclude a zero divisor, and for srem also the minimum signed value.
        const APInt *Op1Int;
        if (match(Op1, m_APInt(Op1Int)) && !Op1Int->isMinValue() &&
            (I.getOpcode() == Instruction::URem ||
             !Op1Int->isMinSignedValue())) {
          // foldOpIntoPhi will speculate instructions to the end of the PHI's
          // predecessor blocks, so do this only if we know the srem or urem
          // will not fault.
          if (Instruction *NV = foldOpIntoPhi(I, PN))
            return NV;
        }
      }

      // See if we can fold away this rem instruction.
      if (SimplifyDemandedInstructionBits(I))
        return &I;
    }
  }

  return nullptr;
}

/// Visitor for urem. Tries instruction simplification, the vector binop fold,
/// the transforms shared with srem, narrowing, and then the urem-specific
/// folds below. Returns nullptr when nothing applied.
Instruction *InstCombinerImpl::visitURem(BinaryOperator &I) {
  if (Value *V = simplifyURemInst(I.getOperand(0), I.getOperand(1),
                                  SQ.getWithInstruction(&I)))
    return replaceInstUsesWith(I, V);

  if (Instruction *X = foldVectorBinop(I))
    return X;

  if (Instruction *common = commonIRemTransforms(I))
    return common;

  if (Instruction *NarrowRem = narrowUDivURem(I, Builder))
    return NarrowRem;

  // X urem Y -> X and Y-1, where Y is a power of 2.
  Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
  Type *Ty = I.getType();
  if (isKnownToBeAPowerOfTwo(Op1, /*OrZero*/ true, 0, &I)) {
    // This may increase instruction count, we don't enforce that Y is a
    // constant.
    // Y - 1 is formed as Y + (-1).
    Constant *N1 = Constant::getAllOnesValue(Ty);
    Value *Add = Builder.CreateAdd(Op1, N1);
    return BinaryOperator::CreateAnd(Op0, Add);
  }

  // 1 urem X -> zext(X != 1)
  if (match(Op0, m_One())) {
    Value *Cmp = Builder.CreateICmpNE(Op1, ConstantInt::get(Ty, 1));
    return CastInst::CreateZExtOrBitCast(Cmp, Ty);
  }

  // Op0 urem C -> Op0 < C ? Op0 : Op0 - C, where C >= signbit.
  // Op0 must be frozen because we are increasing its number of uses.
  if (match(Op1, m_Negative())) {
    Value *F0 = Builder.CreateFreeze(Op0, Op0->getName() + ".fr");
    Value *Cmp = Builder.CreateICmpULT(F0, Op1);
    Value *Sub = Builder.CreateSub(F0, Op1);
    return SelectInst::Create(Cmp, F0, Sub);
  }

  // If the divisor is a sext of a boolean, then the divisor must be max
  // unsigned value (-1). Therefore, the remainder is Op0 unless Op0 is also
  // max unsigned value. In that case, the remainder is 0:
  // urem Op0, (sext i1 X) --> (Op0 == -1) ? 0 : Op0
  Value *X;
  if (match(Op1, m_SExt(m_Value(X))) && X->getType()->isIntOrIntVectorTy(1)) {
    Value *Cmp = Builder.CreateICmpEQ(Op0, ConstantInt::getAllOnesValue(Ty));
    return SelectInst::Create(Cmp, ConstantInt::getNullValue(Ty), Op0);
  }

  return nullptr;
}

/// Visitor for srem. Tries instruction simplification, the vector binop fold,
/// the transforms shared with urem, and then the srem-specific folds below.
/// Returns nullptr when nothing applied.
Instruction *InstCombinerImpl::visitSRem(BinaryOperator &I) {
  if (Value *V = simplifySRemInst(I.getOperand(0), I.getOperand(1),
                                  SQ.getWithInstruction(&I)))
    return replaceInstUsesWith(I, V);

  if (Instruction *X = foldVectorBinop(I))
    return X;

  // Handle the integer rem common cases
  if (Instruction *Common = commonIRemTransforms(I))
    return Common;

  Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
  {
    const APInt *Y;
    // X % -Y -> X % Y
    // The minimum signed value is skipped: negating it yields itself.
    if (match(Op1, m_Negative(Y)) && !Y->isMinSignedValue())
      return replaceOperand(I, 1, ConstantInt::get(I.getType(), -*Y));
  }

  // -X srem Y --> -(X srem Y)
  Value *X, *Y;
  if (match(&I, m_SRem(m_OneUse(m_NSWSub(m_Zero(), m_Value(X))), m_Value(Y))))
    return BinaryOperator::CreateNSWNeg(Builder.CreateSRem(X, Y));

  // If the sign bits of both operands are zero (i.e. we can prove they are
  // unsigned inputs), turn this into a urem.
  APInt Mask(APInt::getSignMask(I.getType()->getScalarSizeInBits()));
  if (MaskedValueIsZero(Op1, Mask, 0, &I) &&
      MaskedValueIsZero(Op0, Mask, 0, &I)) {
    // X srem Y -> X urem Y, iff X and Y don't have sign bit set
    return BinaryOperator::CreateURem(Op0, Op1, I.getName());
  }

  // If it's a constant vector, flip any negative values positive.
  if (isa<ConstantVector>(Op1) || isa<ConstantDataVector>(Op1)) {
    Constant *C = cast<Constant>(Op1);
    unsigned VWidth = cast<FixedVectorType>(C->getType())->getNumElements();

    // First pass: look for a negative element, and give up entirely if any
    // element cannot be extracted.
    bool hasNegative = false;
    bool hasMissing = false;
    for (unsigned i = 0; i != VWidth; ++i) {
      Constant *Elt = C->getAggregateElement(i);
      if (!Elt) {
        hasMissing = true;
        break;
      }

      if (ConstantInt *RHS = dyn_cast<ConstantInt>(Elt))
        if (RHS->isNegative())
          hasNegative = true;
    }

    // Second pass: rebuild the vector with negative ConstantInt elements
    // negated; every other element is copied through unchanged.
    if (hasNegative && !hasMissing) {
      SmallVector<Constant *, 16> Elts(VWidth);
      for (unsigned i = 0; i != VWidth; ++i) {
        Elts[i] = C->getAggregateElement(i); // Handle undef, etc.
        if (ConstantInt *RHS = dyn_cast<ConstantInt>(Elts[i])) {
          if (RHS->isNegative())
            Elts[i] = cast<ConstantInt>(ConstantExpr::getNeg(RHS));
        }
      }

      Constant *NewRHSV = ConstantVector::get(Elts);
      if (NewRHSV != C) // Don't loop on -MININT
        return replaceOperand(I, 1, NewRHSV);
    }
  }

  return nullptr;
}

/// Visitor for frem. Only instruction simplification and the generic
/// vector/phi folds are attempted; returns nullptr otherwise.
Instruction *InstCombinerImpl::visitFRem(BinaryOperator &I) {
  if (Value *V = simplifyFRemInst(I.getOperand(0), I.getOperand(1),
                                  I.getFastMathFlags(),
                                  SQ.getWithInstruction(&I)))
    return replaceInstUsesWith(I, V);

  if (Instruction *X = foldVectorBinop(I))
    return X;

  if (Instruction *Phi = foldBinopWithPhiOperands(I))
    return Phi;

  return nullptr;
}
diff --git a/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineNegator.cpp b/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineNegator.cpp new file mode 100644 index 000000000000..c573b03f31a6 --- /dev/null +++ b/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineNegator.cpp @@ -0,0 +1,559 @@ +//===- InstCombineNegator.cpp -----------------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements sinking of negation into expression trees, +// as long as that can be done without increasing instruction count.
//
//===----------------------------------------------------------------------===//

#include "InstCombineInternal.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/None.h"
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/Twine.h"
#include "llvm/ADT/iterator_range.h"
#include "llvm/Analysis/TargetFolder.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/IR/Constant.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DebugLoc.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/Use.h"
#include "llvm/IR/User.h"
#include "llvm/IR/Value.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Compiler.h"
#include "llvm/Support/DebugCounter.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Transforms/InstCombine/InstCombiner.h"
#include <cassert>
#include <cstdint>
#include <functional>
#include <tuple>
#include <type_traits>
#include <utility>

// Forward declarations of the types taken by reference in the Negator
// constructor below.
namespace llvm {
class AssumptionCache;
class DataLayout;
class DominatorTree;
class LLVMContext;
} // namespace llvm

using namespace llvm;

#define DEBUG_TYPE "instcombine"

// Statistics kept by the Negator. The description strings are runtime text.
STATISTIC(NegatorTotalNegationsAttempted,
          "Negator: Number of negations attempted to be sinked");
STATISTIC(NegatorNumTreesNegated,
          "Negator: Number of negations successfully sinked");
STATISTIC(NegatorMaxDepthVisited, "Negator: Maximal traversal depth ever "
                                  "reached while attempting to sink negation");
STATISTIC(NegatorTimesDepthLimitReached,
          "Negator: How many times did the traversal depth limit was reached "
          "during sinking");
STATISTIC(
    NegatorNumValuesVisited,
    "Negator: Total number of values visited during attempts to sink negation");
STATISTIC(NegatorNumNegationsFoundInCache,
          "Negator: How many negations did we retrieve/reuse from cache");
STATISTIC(NegatorMaxTotalValuesVisited,
          "Negator: Maximal number of values ever visited while attempting to "
          "sink negation");
STATISTIC(NegatorNumInstructionsCreatedTotal,
          "Negator: Number of new negated instructions created, total");
STATISTIC(NegatorMaxInstructionsCreated,
          "Negator: Maximal number of new instructions created during negation "
          "attempt");
STATISTIC(NegatorNumInstructionsNegatedSuccess,
          "Negator: Number of new negated instructions created in successful "
          "negation sinking attempts");

DEBUG_COUNTER(NegatorCounter, "instcombine-negator",
              "Controls Negator transformations in InstCombine pass");

// Command-line switch: whether negation sinking is attempted (default: true).
static cl::opt<bool>
    NegatorEnabled("instcombine-negator-enabled", cl::init(true),
                   cl::desc("Should we attempt to sink negations?"));

// Command-line limit on the lookup depth; defaults to NegatorDefaultMaxDepth.
static cl::opt<unsigned>
    NegatorMaxDepth("instcombine-negator-max-depth",
                    cl::init(NegatorDefaultMaxDepth),
                    cl::desc("What is the maximal lookup depth when trying to "
                             "check for viability of negation sinking."));

/// Construct a Negator.
///
/// The internal Builder folds constants through a TargetFolder over \p DL_
/// and uses an inserter callback that, for every instruction it is handed,
/// bumps NegatorNumInstructionsCreatedTotal and appends the instruction to
/// NewInstructions. The remaining arguments are stored in the corresponding
/// members.
Negator::Negator(LLVMContext &C, const DataLayout &DL_, AssumptionCache &AC_,
                 const DominatorTree &DT_, bool IsTrulyNegation_)
    : Builder(C, TargetFolder(DL_),
              IRBuilderCallbackInserter([&](Instruction *I) {
                ++NegatorNumInstructionsCreatedTotal;
                NewInstructions.push_back(I);
              })),
      DL(DL_), AC(AC_), DT(DT_), IsTrulyNegation(IsTrulyNegation_) {}

#if LLVM_ENABLE_STATS
/// Only compiled when LLVM_ENABLE_STATS is set: folds this instance's
/// visited-value count into the NegatorMaxTotalValuesVisited maximum.
Negator::~Negator() {
  NegatorMaxTotalValuesVisited.updateMax(NumValuesVisitedInThisNegator);
}
#endif

// Due to the InstCombine's worklist management, there are no guarantees that
// each instruction we'll encounter has been visited by InstCombine already.
// In particular, most importantly for us, that means we have to canonicalize
// constants to RHS ourselves, since that is helpful sometimes.
+std::array<Value *, 2> Negator::getSortedOperandsOfBinOp(Instruction *I) { + assert(I->getNumOperands() == 2 && "Only for binops!"); + std::array<Value *, 2> Ops{I->getOperand(0), I->getOperand(1)}; + if (I->isCommutative() && InstCombiner::getComplexity(I->getOperand(0)) < + InstCombiner::getComplexity(I->getOperand(1))) + std::swap(Ops[0], Ops[1]); + return Ops; +} + +// FIXME: can this be reworked into a worklist-based algorithm while preserving +// the depth-first, early bailout traversal? +LLVM_NODISCARD Value *Negator::visitImpl(Value *V, unsigned Depth) { + // -(undef) -> undef. + if (match(V, m_Undef())) + return V; + + // In i1, negation can simply be ignored. + if (V->getType()->isIntOrIntVectorTy(1)) + return V; + + Value *X; + + // -(-(X)) -> X. + if (match(V, m_Neg(m_Value(X)))) + return X; + + // Integral constants can be freely negated. + if (match(V, m_AnyIntegralConstant())) + return ConstantExpr::getNeg(cast<Constant>(V), /*HasNUW=*/false, + /*HasNSW=*/false); + + // If we have a non-instruction, then give up. + if (!isa<Instruction>(V)) + return nullptr; + + // If we have started with a true negation (i.e. `sub 0, %y`), then if we've + // got an instruction that does not require recursive reasoning, we can still + // negate it even if it has other uses, without increasing instruction count. + if (!V->hasOneUse() && !IsTrulyNegation) + return nullptr; + + auto *I = cast<Instruction>(V); + unsigned BitWidth = I->getType()->getScalarSizeInBits(); + + // We must preserve the insertion point and debug info that is set in the + // builder at the time this function is called. + InstCombiner::BuilderTy::InsertPointGuard Guard(Builder); + // And since we are trying to negate instruction I, that tells us about the + // insertion point and the debug info that we need to keep. + Builder.SetInsertPoint(I); + + // In some cases we can give the answer without further recursion.
+ switch (I->getOpcode()) { + case Instruction::Add: { + std::array<Value *, 2> Ops = getSortedOperandsOfBinOp(I); + // `inc` is always negatible. + if (match(Ops[1], m_One())) + return Builder.CreateNot(Ops[0], I->getName() + ".neg"); + break; + } + case Instruction::Xor: + // `not` is always negatible. + if (match(I, m_Not(m_Value(X)))) + return Builder.CreateAdd(X, ConstantInt::get(X->getType(), 1), + I->getName() + ".neg"); + break; + case Instruction::AShr: + case Instruction::LShr: { + // Right-shift sign bit smear is negatible. + const APInt *Op1Val; + if (match(I->getOperand(1), m_APInt(Op1Val)) && *Op1Val == BitWidth - 1) { + Value *BO = I->getOpcode() == Instruction::AShr + ? Builder.CreateLShr(I->getOperand(0), I->getOperand(1)) + : Builder.CreateAShr(I->getOperand(0), I->getOperand(1)); + if (auto *NewInstr = dyn_cast<Instruction>(BO)) { + NewInstr->copyIRFlags(I); + NewInstr->setName(I->getName() + ".neg"); + } + return BO; + } + // While we could negate exact arithmetic shift: + // ashr exact %x, C --> sdiv exact i8 %x, -1<<C + // iff C != 0 and C u< bitwidth(%x), we don't want to, + // because division is *THAT* much worse than a shift. + break; + } + case Instruction::SExt: + case Instruction::ZExt: + // `*ext` of i1 is always negatible + if (I->getOperand(0)->getType()->isIntOrIntVectorTy(1)) + return I->getOpcode() == Instruction::SExt + ? Builder.CreateZExt(I->getOperand(0), I->getType(), + I->getName() + ".neg") + : Builder.CreateSExt(I->getOperand(0), I->getType(), + I->getName() + ".neg"); + break; + case Instruction::Select: { + // If both arms of the select are constants, we don't need to recurse. + // Therefore, this transform is not limited by uses.
+ auto *Sel = cast<SelectInst>(I); + Constant *TrueC, *FalseC; + if (match(Sel->getTrueValue(), m_ImmConstant(TrueC)) && + match(Sel->getFalseValue(), m_ImmConstant(FalseC))) { + Constant *NegTrueC = ConstantExpr::getNeg(TrueC); + Constant *NegFalseC = ConstantExpr::getNeg(FalseC); + return Builder.CreateSelect(Sel->getCondition(), NegTrueC, NegFalseC, + I->getName() + ".neg", /*MDFrom=*/I); + } + break; + } + default: + break; // Other instructions require recursive reasoning. + } + + if (I->getOpcode() == Instruction::Sub && + (I->hasOneUse() || match(I->getOperand(0), m_ImmConstant()))) { + // `sub` is always negatible. + // However, only do this either if the old `sub` doesn't stick around, or + // it was subtracting from a constant. Otherwise, this isn't profitable. + return Builder.CreateSub(I->getOperand(1), I->getOperand(0), + I->getName() + ".neg"); + } + + // Some other cases, while they still don't require recursion, + // are restricted to the one-use case. + if (!V->hasOneUse()) + return nullptr; + + switch (I->getOpcode()) { + case Instruction::And: { + Constant *ShAmt; + // sub(y,and(lshr(x,C),1)) --> add(ashr(shl(x,(BW-1)-C),BW-1),y) + if (match(I, m_c_And(m_OneUse(m_TruncOrSelf( + m_LShr(m_Value(X), m_ImmConstant(ShAmt)))), + m_One()))) { + unsigned BW = X->getType()->getScalarSizeInBits(); + Constant *BWMinusOne = ConstantInt::get(X->getType(), BW - 1); + Value *R = Builder.CreateShl(X, Builder.CreateSub(BWMinusOne, ShAmt)); + R = Builder.CreateAShr(R, BWMinusOne); + return Builder.CreateTruncOrBitCast(R, I->getType()); + } + break; + } + case Instruction::SDiv: + // `sdiv` is negatible if divisor is not undef/INT_MIN/1. + // While this is normally not behind a use-check, + // let's consider division to be special since it's costly.
+ if (auto *Op1C = dyn_cast<Constant>(I->getOperand(1))) { + if (!Op1C->containsUndefOrPoisonElement() && + Op1C->isNotMinSignedValue() && Op1C->isNotOneValue()) { + Value *BO = + Builder.CreateSDiv(I->getOperand(0), ConstantExpr::getNeg(Op1C), + I->getName() + ".neg"); + if (auto *NewInstr = dyn_cast<Instruction>(BO)) + NewInstr->setIsExact(I->isExact()); + return BO; + } + } + break; + } + + // Rest of the logic is recursive, so if it's time to give up then it's time. + if (Depth > NegatorMaxDepth) { + LLVM_DEBUG(dbgs() << "Negator: reached maximal allowed traversal depth in " + << *V << ". Giving up.\n"); + ++NegatorTimesDepthLimitReached; + return nullptr; + } + + switch (I->getOpcode()) { + case Instruction::Freeze: { + // `freeze` is negatible if its operand is negatible. + Value *NegOp = negate(I->getOperand(0), Depth + 1); + if (!NegOp) // Early return. + return nullptr; + return Builder.CreateFreeze(NegOp, I->getName() + ".neg"); + } + case Instruction::PHI: { + // `phi` is negatible if all the incoming values are negatible. + auto *PHI = cast<PHINode>(I); + SmallVector<Value *, 4> NegatedIncomingValues(PHI->getNumOperands()); + for (auto I : zip(PHI->incoming_values(), NegatedIncomingValues)) { + if (!(std::get<1>(I) = + negate(std::get<0>(I), Depth + 1))) // Early return. + return nullptr; + } + // All incoming values are indeed negatible. Create negated PHI node. + PHINode *NegatedPHI = Builder.CreatePHI( + PHI->getType(), PHI->getNumOperands(), PHI->getName() + ".neg"); + for (auto I : zip(NegatedIncomingValues, PHI->blocks())) + NegatedPHI->addIncoming(std::get<0>(I), std::get<1>(I)); + return NegatedPHI; + } + case Instruction::Select: { + if (isKnownNegation(I->getOperand(1), I->getOperand(2))) { + // If one hand of select is known to be negation of another hand, + // just swap the hands around. + auto *NewSelect = cast<SelectInst>(I->clone()); + // Just swap the operands of the select.
+ NewSelect->swapValues(); + // Don't swap prof metadata, we didn't change the branch behavior. + NewSelect->setName(I->getName() + ".neg"); + Builder.Insert(NewSelect); + return NewSelect; + } + // `select` is negatible if both hands of `select` are negatible. + Value *NegOp1 = negate(I->getOperand(1), Depth + 1); + if (!NegOp1) // Early return. + return nullptr; + Value *NegOp2 = negate(I->getOperand(2), Depth + 1); + if (!NegOp2) + return nullptr; + // Do preserve the metadata! + return Builder.CreateSelect(I->getOperand(0), NegOp1, NegOp2, + I->getName() + ".neg", /*MDFrom=*/I); + } + case Instruction::ShuffleVector: { + // `shufflevector` is negatible if both operands are negatible. + auto *Shuf = cast<ShuffleVectorInst>(I); + Value *NegOp0 = negate(I->getOperand(0), Depth + 1); + if (!NegOp0) // Early return. + return nullptr; + Value *NegOp1 = negate(I->getOperand(1), Depth + 1); + if (!NegOp1) + return nullptr; + return Builder.CreateShuffleVector(NegOp0, NegOp1, Shuf->getShuffleMask(), + I->getName() + ".neg"); + } + case Instruction::ExtractElement: { + // `extractelement` is negatible if source operand is negatible. + auto *EEI = cast<ExtractElementInst>(I); + Value *NegVector = negate(EEI->getVectorOperand(), Depth + 1); + if (!NegVector) // Early return. + return nullptr; + return Builder.CreateExtractElement(NegVector, EEI->getIndexOperand(), + I->getName() + ".neg"); + } + case Instruction::InsertElement: { + // `insertelement` is negatible if both the source vector and + // element-to-be-inserted are negatible. + auto *IEI = cast<InsertElementInst>(I); + Value *NegVector = negate(IEI->getOperand(0), Depth + 1); + if (!NegVector) // Early return. + return nullptr; + Value *NegNewElt = negate(IEI->getOperand(1), Depth + 1); + if (!NegNewElt) // Early return.
+ return nullptr; + return Builder.CreateInsertElement(NegVector, NegNewElt, IEI->getOperand(2), + I->getName() + ".neg"); + } + case Instruction::Trunc: { + // `trunc` is negatible if its operand is negatible. + Value *NegOp = negate(I->getOperand(0), Depth + 1); + if (!NegOp) // Early return. + return nullptr; + return Builder.CreateTrunc(NegOp, I->getType(), I->getName() + ".neg"); + } + case Instruction::Shl: { + // `shl` is negatible if the first operand is negatible. + if (Value *NegOp0 = negate(I->getOperand(0), Depth + 1)) + return Builder.CreateShl(NegOp0, I->getOperand(1), I->getName() + ".neg"); + // Otherwise, `shl %x, C` can be interpreted as `mul %x, 1<<C`. + auto *Op1C = dyn_cast<Constant>(I->getOperand(1)); + if (!Op1C) // Early return. + return nullptr; + return Builder.CreateMul( + I->getOperand(0), + ConstantExpr::getShl(Constant::getAllOnesValue(Op1C->getType()), Op1C), + I->getName() + ".neg"); + } + case Instruction::Or: { + if (!haveNoCommonBitsSet(I->getOperand(0), I->getOperand(1), DL, &AC, I, + &DT)) + return nullptr; // Don't know how to handle `or` in general. + std::array<Value *, 2> Ops = getSortedOperandsOfBinOp(I); + // `or`/`add` are interchangeable when operands have no common bits set. + // `inc` is always negatible. + if (match(Ops[1], m_One())) + return Builder.CreateNot(Ops[0], I->getName() + ".neg"); + // Else, just defer to Instruction::Add handling. + LLVM_FALLTHROUGH; + } + case Instruction::Add: { + // `add` is negatible if both of its operands are negatible. + SmallVector<Value *, 2> NegatedOps, NonNegatedOps; + for (Value *Op : I->operands()) { + // Can we sink the negation into this operand? + if (Value *NegOp = negate(Op, Depth + 1)) { + NegatedOps.emplace_back(NegOp); // Successfully negated operand! + continue; + } + // Failed to sink negation into this operand. IFF we started from negation + // and we manage to sink negation into one operand, we can still do this.
+ if (!IsTrulyNegation) + return nullptr; + NonNegatedOps.emplace_back(Op); // Just record which operand that was. + } + assert((NegatedOps.size() + NonNegatedOps.size()) == 2 && + "Internal consistency check failed."); + // Did we manage to sink negation into both of the operands? + if (NegatedOps.size() == 2) // Then we get to keep the `add`! + return Builder.CreateAdd(NegatedOps[0], NegatedOps[1], + I->getName() + ".neg"); + assert(IsTrulyNegation && "We should have early-exited then."); + // Completely failed to sink negation? + if (NonNegatedOps.size() == 2) + return nullptr; + // 0-(a+b) --> (-a)-b + return Builder.CreateSub(NegatedOps[0], NonNegatedOps[0], + I->getName() + ".neg"); + } + case Instruction::Xor: { + std::array<Value *, 2> Ops = getSortedOperandsOfBinOp(I); + // `xor` is negatible if one of its operands is invertible. + // FIXME: InstCombineInverter? But how to connect Inverter and Negator? + if (auto *C = dyn_cast<Constant>(Ops[1])) { + Value *Xor = Builder.CreateXor(Ops[0], ConstantExpr::getNot(C)); + return Builder.CreateAdd(Xor, ConstantInt::get(Xor->getType(), 1), + I->getName() + ".neg"); + } + return nullptr; + } + case Instruction::Mul: { + std::array<Value *, 2> Ops = getSortedOperandsOfBinOp(I); + // `mul` is negatible if one of its operands is negatible. + Value *NegatedOp, *OtherOp; + // First try the second operand, in case it's a constant it will be best to + // just invert it instead of sinking the `neg` deeper. + if (Value *NegOp1 = negate(Ops[1], Depth + 1)) { + NegatedOp = NegOp1; + OtherOp = Ops[0]; + } else if (Value *NegOp0 = negate(Ops[0], Depth + 1)) { + NegatedOp = NegOp0; + OtherOp = Ops[1]; + } else + // Can't negate either of them. + return nullptr; + return Builder.CreateMul(NegatedOp, OtherOp, I->getName() + ".neg"); + } + default: + return nullptr; // Don't know, likely not negatible for free. + } + + llvm_unreachable("Can't get here.
We always return from switch."); +} + +LLVM_NODISCARD Value *Negator::negate(Value *V, unsigned Depth) { + NegatorMaxDepthVisited.updateMax(Depth); + ++NegatorNumValuesVisited; + +#if LLVM_ENABLE_STATS + ++NumValuesVisitedInThisNegator; +#endif + +#ifndef NDEBUG + // We can't ever have a Value with such an address. + Value *Placeholder = reinterpret_cast<Value *>(static_cast<uintptr_t>(-1)); +#endif + + // Did we already try to negate this value? + auto NegationsCacheIterator = NegationsCache.find(V); + if (NegationsCacheIterator != NegationsCache.end()) { + ++NegatorNumNegationsFoundInCache; + Value *NegatedV = NegationsCacheIterator->second; + assert(NegatedV != Placeholder && "Encountered a cycle during negation."); + return NegatedV; + } + +#ifndef NDEBUG + // We did not find a cached result for negation of V. While there, + // let's temporarily cache a placeholder value, with the idea that if later + // during negation we fetch it from cache, we'll know we're in a cycle. + NegationsCache[V] = Placeholder; +#endif + + // No luck. Try negating it for real. + Value *NegatedV = visitImpl(V, Depth); + // And cache the (real) result for the future. + NegationsCache[V] = NegatedV; + + return NegatedV; +} + +LLVM_NODISCARD Optional<Negator::Result> Negator::run(Value *Root) { + Value *Negated = negate(Root, /*Depth=*/0); + if (!Negated) { + // We must cleanup newly-inserted instructions, to avoid any potential + // endless combine looping.
+ for (Instruction *I : llvm::reverse(NewInstructions)) + I->eraseFromParent(); + return llvm::None; + } + return std::make_pair(ArrayRef<Instruction *>(NewInstructions), Negated); +} + +LLVM_NODISCARD Value *Negator::Negate(bool LHSIsZero, Value *Root, + InstCombinerImpl &IC) { + ++NegatorTotalNegationsAttempted; + LLVM_DEBUG(dbgs() << "Negator: attempting to sink negation into " << *Root + << "\n"); + + if (!NegatorEnabled || !DebugCounter::shouldExecute(NegatorCounter)) + return nullptr; + + Negator N(Root->getContext(), IC.getDataLayout(), IC.getAssumptionCache(), + IC.getDominatorTree(), LHSIsZero); + Optional<Result> Res = N.run(Root); + if (!Res) { // Negation failed. + LLVM_DEBUG(dbgs() << "Negator: failed to sink negation into " << *Root + << "\n"); + return nullptr; + } + + LLVM_DEBUG(dbgs() << "Negator: successfully sunk negation into " << *Root + << "\n NEW: " << *Res->second << "\n"); + ++NegatorNumTreesNegated; + + // We must temporarily unset the 'current' insertion point and DebugLoc of the + // InstCombine's IRBuilder so that it won't interfere with the ones we have + // already specified when producing negated instructions. + InstCombiner::BuilderTy::InsertPointGuard Guard(IC.Builder); + IC.Builder.ClearInsertionPoint(); + IC.Builder.SetCurrentDebugLocation(DebugLoc()); + + // And finally, we must add newly-created instructions into the InstCombine's + // worklist (in a proper order!) so it can attempt to combine them. + LLVM_DEBUG(dbgs() << "Negator: Propagating " << Res->first.size() + << " instrs to InstCombine\n"); + NegatorMaxInstructionsCreated.updateMax(Res->first.size()); + NegatorNumInstructionsNegatedSuccess += Res->first.size(); + + // They are in def-use order, so nothing fancy, just insert them in order. + for (Instruction *I : Res->first) + IC.Builder.Insert(I, I->getName()); + + // And return the new root.
+ return Res->second; +} diff --git a/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp b/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp new file mode 100644 index 000000000000..90a796a0939e --- /dev/null +++ b/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp @@ -0,0 +1,1557 @@ +//===- InstCombinePHI.cpp -------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the visitPHINode function. +// +//===----------------------------------------------------------------------===// + +#include "InstCombineInternal.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/ADT/Statistic.h" +#include "llvm/Analysis/InstructionSimplify.h" +#include "llvm/Analysis/ValueTracking.h" +#include "llvm/IR/PatternMatch.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Transforms/InstCombine/InstCombiner.h" +#include "llvm/Transforms/Utils/Local.h" + +using namespace llvm; +using namespace llvm::PatternMatch; + +#define DEBUG_TYPE "instcombine" + +static cl::opt<unsigned> +MaxNumPhis("instcombine-max-num-phis", cl::init(512), + cl::desc("Maximum number phis to handle in intptr/ptrint folding")); + +STATISTIC(NumPHIsOfInsertValues, + "Number of phi-of-insertvalue turned into insertvalue-of-phis"); +STATISTIC(NumPHIsOfExtractValues, + "Number of phi-of-extractvalue turned into extractvalue-of-phi"); +STATISTIC(NumPHICSEs, "Number of PHI's that got CSE'd"); + +/// The PHI arguments will be folded into a single operation with a PHI node +/// as input. The debug location of the single operation will be the merged +/// locations of the original PHI node arguments.
+void InstCombinerImpl::PHIArgMergedDebugLoc(Instruction *Inst, PHINode &PN) { + auto *FirstInst = cast<Instruction>(PN.getIncomingValue(0)); + Inst->setDebugLoc(FirstInst->getDebugLoc()); + // We do not expect a CallInst here, otherwise, N-way merging of DebugLoc + // will be inefficient. + assert(!isa<CallInst>(Inst)); + + for (Value *V : drop_begin(PN.incoming_values())) { + auto *I = cast<Instruction>(V); + Inst->applyMergedLocation(Inst->getDebugLoc(), I->getDebugLoc()); + } +} + +// Replace Integer typed PHI PN if the PHI's value is used as a pointer value. +// If there is an existing pointer typed PHI that produces the same value as PN, +// replace PN and the IntToPtr operation with it. Otherwise, synthesize a new +// PHI node: +// +// Case-1: +// bb1: +// int_init = PtrToInt(ptr_init) +// br label %bb2 +// bb2: +// int_val = PHI([int_init, %bb1], [int_val_inc, %bb2]) +// ptr_val = PHI([ptr_init, %bb1], [ptr_val_inc, %bb2]) +// ptr_val2 = IntToPtr(int_val) +// ... +// use(ptr_val2) +// ptr_val_inc = ... +// int_val_inc = PtrToInt(ptr_val_inc) +// +// ==> +// bb1: +// br label %bb2 +// bb2: +// ptr_val = PHI([ptr_init, %bb1], [ptr_val_inc, %bb2]) +// ... +// use(ptr_val) +// ptr_val_inc = ... +// +// Case-2: +// bb1: +// int_ptr = BitCast(ptr_ptr) +// int_init = Load(int_ptr) +// br label %bb2 +// bb2: +// int_val = PHI([int_init, %bb1], [int_val_inc, %bb2]) +// ptr_val2 = IntToPtr(int_val) +// ... +// use(ptr_val2) +// ptr_val_inc = ... +// int_val_inc = PtrToInt(ptr_val_inc) +// ==> +// bb1: +// ptr_init = Load(ptr_ptr) +// br label %bb2 +// bb2: +// ptr_val = PHI([ptr_init, %bb1], [ptr_val_inc, %bb2]) +// ... +// use(ptr_val) +// ptr_val_inc = ... +// ...
+// +Instruction *InstCombinerImpl::foldIntegerTypedPHI(PHINode &PN) { + if (!PN.getType()->isIntegerTy()) + return nullptr; + if (!PN.hasOneUse()) + return nullptr; + + auto *IntToPtr = dyn_cast<IntToPtrInst>(PN.user_back()); + if (!IntToPtr) + return nullptr; + + // Check if the pointer is actually used as pointer: + auto HasPointerUse = [](Instruction *IIP) { + for (User *U : IIP->users()) { + Value *Ptr = nullptr; + if (LoadInst *LoadI = dyn_cast<LoadInst>(U)) { + Ptr = LoadI->getPointerOperand(); + } else if (StoreInst *SI = dyn_cast<StoreInst>(U)) { + Ptr = SI->getPointerOperand(); + } else if (GetElementPtrInst *GI = dyn_cast<GetElementPtrInst>(U)) { + Ptr = GI->getPointerOperand(); + } + + if (Ptr && Ptr == IIP) + return true; + } + return false; + }; + + if (!HasPointerUse(IntToPtr)) + return nullptr; + + if (DL.getPointerSizeInBits(IntToPtr->getAddressSpace()) != + DL.getTypeSizeInBits(IntToPtr->getOperand(0)->getType())) + return nullptr; + + SmallVector<Value *, 4> AvailablePtrVals; + for (auto Incoming : zip(PN.blocks(), PN.incoming_values())) { + BasicBlock *BB = std::get<0>(Incoming); + Value *Arg = std::get<1>(Incoming); + + // First look backward: + if (auto *PI = dyn_cast<PtrToIntInst>(Arg)) { + AvailablePtrVals.emplace_back(PI->getOperand(0)); + continue; + } + + // Next look forward: + Value *ArgIntToPtr = nullptr; + for (User *U : Arg->users()) { + if (isa<IntToPtrInst>(U) && U->getType() == IntToPtr->getType() && + (DT.dominates(cast<Instruction>(U), BB) || + cast<Instruction>(U)->getParent() == BB)) { + ArgIntToPtr = U; + break; + } + } + + if (ArgIntToPtr) { + AvailablePtrVals.emplace_back(ArgIntToPtr); + continue; + } + + // If Arg is defined by a PHI, allow it. This will also create + // more opportunities iteratively.
+ if (isa<PHINode>(Arg)) { + AvailablePtrVals.emplace_back(Arg); + continue; + } + + // For a single use integer load: + auto *LoadI = dyn_cast<LoadInst>(Arg); + if (!LoadI) + return nullptr; + + if (!LoadI->hasOneUse()) + return nullptr; + + // Push the integer typed Load instruction into the available + // value set, and fix it up later when the pointer typed PHI + // is synthesized. + AvailablePtrVals.emplace_back(LoadI); + } + + // Now search for a matching PHI + auto *BB = PN.getParent(); + assert(AvailablePtrVals.size() == PN.getNumIncomingValues() && + "Not enough available ptr typed incoming values"); + PHINode *MatchingPtrPHI = nullptr; + unsigned NumPhis = 0; + for (PHINode &PtrPHI : BB->phis()) { + // FIXME: consider handling this in AggressiveInstCombine + if (NumPhis++ > MaxNumPhis) + return nullptr; + if (&PtrPHI == &PN || PtrPHI.getType() != IntToPtr->getType()) + continue; + if (any_of(zip(PN.blocks(), AvailablePtrVals), + [&](const auto &BlockAndValue) { + BasicBlock *BB = std::get<0>(BlockAndValue); + Value *V = std::get<1>(BlockAndValue); + return PtrPHI.getIncomingValueForBlock(BB) != V; + })) + continue; + MatchingPtrPHI = &PtrPHI; + break; + } + + if (MatchingPtrPHI) { + assert(MatchingPtrPHI->getType() == IntToPtr->getType() && + "Phi's Type does not match with IntToPtr"); + // The PtrToInt + IntToPtr will be simplified later + return CastInst::CreateBitOrPointerCast(MatchingPtrPHI, + IntToPtr->getOperand(0)->getType()); + } + + // If it requires a conversion for every PHI operand, do not do it. + if (all_of(AvailablePtrVals, [&](Value *V) { + return (V->getType() != IntToPtr->getType()) || isa<IntToPtrInst>(V); + })) + return nullptr; + + // If any of the operands that require casting is a terminator + // instruction, do not do it. Similarly, do not do the transform if the value + // is PHI in a block with no insertion point, for example, a catchswitch + // block, since we will not be able to insert a cast after the PHI.
+ if (any_of(AvailablePtrVals, [&](Value *V) { + if (V->getType() == IntToPtr->getType()) + return false; + auto *Inst = dyn_cast<Instruction>(V); + if (!Inst) + return false; + if (Inst->isTerminator()) + return true; + auto *BB = Inst->getParent(); + if (isa<PHINode>(Inst) && BB->getFirstInsertionPt() == BB->end()) + return true; + return false; + })) + return nullptr; + + PHINode *NewPtrPHI = PHINode::Create( + IntToPtr->getType(), PN.getNumIncomingValues(), PN.getName() + ".ptr"); + + InsertNewInstBefore(NewPtrPHI, PN); + SmallDenseMap<Value *, Instruction *> Casts; + for (auto Incoming : zip(PN.blocks(), AvailablePtrVals)) { + auto *IncomingBB = std::get<0>(Incoming); + auto *IncomingVal = std::get<1>(Incoming); + + if (IncomingVal->getType() == IntToPtr->getType()) { + NewPtrPHI->addIncoming(IncomingVal, IncomingBB); + continue; + } + +#ifndef NDEBUG + LoadInst *LoadI = dyn_cast<LoadInst>(IncomingVal); + assert((isa<PHINode>(IncomingVal) || + IncomingVal->getType()->isPointerTy() || + (LoadI && LoadI->hasOneUse())) && + "Can not replace LoadInst with multiple uses"); +#endif + // Need to insert a BitCast.
+ // For an integer Load instruction with a single use, the load + IntToPtr + // cast will be simplified into a pointer load: + // %v = load i64, i64* %a.ip, align 8 + // %v.cast = inttoptr i64 %v to float ** + // ==> + // %v.ptrp = bitcast i64 * %a.ip to float ** + // %v.cast = load float *, float ** %v.ptrp, align 8 + Instruction *&CI = Casts[IncomingVal]; + if (!CI) { + CI = CastInst::CreateBitOrPointerCast(IncomingVal, IntToPtr->getType(), + IncomingVal->getName() + ".ptr"); + if (auto *IncomingI = dyn_cast<Instruction>(IncomingVal)) { + BasicBlock::iterator InsertPos(IncomingI); + InsertPos++; + BasicBlock *BB = IncomingI->getParent(); + if (isa<PHINode>(IncomingI)) + InsertPos = BB->getFirstInsertionPt(); + assert(InsertPos != BB->end() && "should have checked above"); + InsertNewInstBefore(CI, *InsertPos); + } else { + auto *InsertBB = &IncomingBB->getParent()->getEntryBlock(); + InsertNewInstBefore(CI, *InsertBB->getFirstInsertionPt()); + } + } + NewPtrPHI->addIncoming(CI, IncomingBB); + } + + // The PtrToInt + IntToPtr will be simplified later + return CastInst::CreateBitOrPointerCast(NewPtrPHI, + IntToPtr->getOperand(0)->getType()); +} + +// Remove RoundTrip IntToPtr/PtrToInt Cast on PHI-Operand and +// fold Phi-operand to bitcast. +Instruction *InstCombinerImpl::foldPHIArgIntToPtrToPHI(PHINode &PN) { + // convert ptr2int ( phi[ int2ptr(ptr2int(x))] ) --> ptr2int ( phi [ x ] ) + // Make sure all uses of phi are ptr2int. + if (!all_of(PN.users(), [](User *U) { return isa<PtrToIntInst>(U); })) + return nullptr; + + // Iterating over all operands to check presence of target pointers for + // optimization.
+ bool OperandWithRoundTripCast = false; + for (unsigned OpNum = 0; OpNum != PN.getNumIncomingValues(); ++OpNum) { + if (auto *NewOp = + simplifyIntToPtrRoundTripCast(PN.getIncomingValue(OpNum))) { + PN.setIncomingValue(OpNum, NewOp); + OperandWithRoundTripCast = true; + } + } + if (!OperandWithRoundTripCast) + return nullptr; + return &PN; +} + +/// If we have something like phi [insertvalue(a,b,0), insertvalue(c,d,0)], +/// turn this into a phi[a,c] and phi[b,d] and a single insertvalue. +Instruction * +InstCombinerImpl::foldPHIArgInsertValueInstructionIntoPHI(PHINode &PN) { + auto *FirstIVI = cast<InsertValueInst>(PN.getIncomingValue(0)); + + // Scan to see if all operands are `insertvalue`'s with the same indices, + // and all have a single use. + for (Value *V : drop_begin(PN.incoming_values())) { + auto *I = dyn_cast<InsertValueInst>(V); + if (!I || !I->hasOneUser() || I->getIndices() != FirstIVI->getIndices()) + return nullptr; + } + + // For each operand of an `insertvalue` + std::array<PHINode *, 2> NewOperands; + for (int OpIdx : {0, 1}) { + auto *&NewOperand = NewOperands[OpIdx]; + // Create a new PHI node to receive the values the operand has in each + // incoming basic block. + NewOperand = PHINode::Create( + FirstIVI->getOperand(OpIdx)->getType(), PN.getNumIncomingValues(), + FirstIVI->getOperand(OpIdx)->getName() + ".pn"); + // And populate each operand's PHI with said values. + for (auto Incoming : zip(PN.blocks(), PN.incoming_values())) + NewOperand->addIncoming( + cast<InsertValueInst>(std::get<1>(Incoming))->getOperand(OpIdx), + std::get<0>(Incoming)); + InsertNewInstBefore(NewOperand, PN); + } + + // And finally, create `insertvalue` over the newly-formed PHI nodes.
+ auto *NewIVI = InsertValueInst::Create(NewOperands[0], NewOperands[1], + FirstIVI->getIndices(), PN.getName()); + + PHIArgMergedDebugLoc(NewIVI, PN); + ++NumPHIsOfInsertValues; + return NewIVI; +} + +/// If we have something like phi [extractvalue(a,0), extractvalue(b,0)], +/// turn this into a phi[a,b] and a single extractvalue. +Instruction * +InstCombinerImpl::foldPHIArgExtractValueInstructionIntoPHI(PHINode &PN) { + auto *FirstEVI = cast<ExtractValueInst>(PN.getIncomingValue(0)); + + // Scan to see if all operands are `extractvalue`'s with the same indices, + // and all have a single use. + for (Value *V : drop_begin(PN.incoming_values())) { + auto *I = dyn_cast<ExtractValueInst>(V); + if (!I || !I->hasOneUser() || I->getIndices() != FirstEVI->getIndices() || + I->getAggregateOperand()->getType() != + FirstEVI->getAggregateOperand()->getType()) + return nullptr; + } + + // Create a new PHI node to receive the values the aggregate operand has + // in each incoming basic block. + auto *NewAggregateOperand = PHINode::Create( + FirstEVI->getAggregateOperand()->getType(), PN.getNumIncomingValues(), + FirstEVI->getAggregateOperand()->getName() + ".pn"); + // And populate the PHI with said values. + for (auto Incoming : zip(PN.blocks(), PN.incoming_values())) + NewAggregateOperand->addIncoming( + cast<ExtractValueInst>(std::get<1>(Incoming))->getAggregateOperand(), + std::get<0>(Incoming)); + InsertNewInstBefore(NewAggregateOperand, PN); + + // And finally, create `extractvalue` over the newly-formed PHI nodes. + auto *NewEVI = ExtractValueInst::Create(NewAggregateOperand, + FirstEVI->getIndices(), PN.getName()); + + PHIArgMergedDebugLoc(NewEVI, PN); + ++NumPHIsOfExtractValues; + return NewEVI; +} + +/// If we have something like phi [add (a,b), add(a,c)] and if a/b/c and the +/// adds all have a single user, turn this into a phi and a single binop.
+Instruction *InstCombinerImpl::foldPHIArgBinOpIntoPHI(PHINode &PN) { + Instruction *FirstInst = cast<Instruction>(PN.getIncomingValue(0)); + assert(isa<BinaryOperator>(FirstInst) || isa<CmpInst>(FirstInst)); + unsigned Opc = FirstInst->getOpcode(); + Value *LHSVal = FirstInst->getOperand(0); + Value *RHSVal = FirstInst->getOperand(1); + + Type *LHSType = LHSVal->getType(); + Type *RHSType = RHSVal->getType(); + + // Scan to see if all operands are the same opcode, and all have one user. + for (Value *V : drop_begin(PN.incoming_values())) { + Instruction *I = dyn_cast<Instruction>(V); + if (!I || I->getOpcode() != Opc || !I->hasOneUser() || + // Verify type of the LHS matches so we don't fold cmp's of different + // types. + I->getOperand(0)->getType() != LHSType || + I->getOperand(1)->getType() != RHSType) + return nullptr; + + // If they are CmpInst instructions, check their predicates + if (CmpInst *CI = dyn_cast<CmpInst>(I)) + if (CI->getPredicate() != cast<CmpInst>(FirstInst)->getPredicate()) + return nullptr; + + // Keep track of which operand needs a phi node. + if (I->getOperand(0) != LHSVal) LHSVal = nullptr; + if (I->getOperand(1) != RHSVal) RHSVal = nullptr; + } + + // If both LHS and RHS would need a PHI, don't do this transformation, + // because it would increase the number of PHIs entering the block, + // which leads to higher register pressure. This is especially + // bad when the PHIs are in the header of a loop. + if (!LHSVal && !RHSVal) + return nullptr; + + // Otherwise, this is safe to transform!
+ + Value *InLHS = FirstInst->getOperand(0); + Value *InRHS = FirstInst->getOperand(1); + PHINode *NewLHS = nullptr, *NewRHS = nullptr; + if (!LHSVal) { + NewLHS = PHINode::Create(LHSType, PN.getNumIncomingValues(), + FirstInst->getOperand(0)->getName() + ".pn"); + NewLHS->addIncoming(InLHS, PN.getIncomingBlock(0)); + InsertNewInstBefore(NewLHS, PN); + LHSVal = NewLHS; + } + + if (!RHSVal) { + NewRHS = PHINode::Create(RHSType, PN.getNumIncomingValues(), + FirstInst->getOperand(1)->getName() + ".pn"); + NewRHS->addIncoming(InRHS, PN.getIncomingBlock(0)); + InsertNewInstBefore(NewRHS, PN); + RHSVal = NewRHS; + } + + // Add all operands to the new PHIs. + if (NewLHS || NewRHS) { + for (auto Incoming : drop_begin(zip(PN.blocks(), PN.incoming_values()))) { + BasicBlock *InBB = std::get<0>(Incoming); + Value *InVal = std::get<1>(Incoming); + Instruction *InInst = cast<Instruction>(InVal); + if (NewLHS) { + Value *NewInLHS = InInst->getOperand(0); + NewLHS->addIncoming(NewInLHS, InBB); + } + if (NewRHS) { + Value *NewInRHS = InInst->getOperand(1); + NewRHS->addIncoming(NewInRHS, InBB); + } + } + } + + if (CmpInst *CIOp = dyn_cast<CmpInst>(FirstInst)) { + CmpInst *NewCI = CmpInst::Create(CIOp->getOpcode(), CIOp->getPredicate(), + LHSVal, RHSVal); + PHIArgMergedDebugLoc(NewCI, PN); + return NewCI; + } + + BinaryOperator *BinOp = cast<BinaryOperator>(FirstInst); + BinaryOperator *NewBinOp = + BinaryOperator::Create(BinOp->getOpcode(), LHSVal, RHSVal); + + NewBinOp->copyIRFlags(PN.getIncomingValue(0)); + + for (Value *V : drop_begin(PN.incoming_values())) + NewBinOp->andIRFlags(V); + + PHIArgMergedDebugLoc(NewBinOp, PN); + return NewBinOp; +} + +Instruction *InstCombinerImpl::foldPHIArgGEPIntoPHI(PHINode &PN) { + GetElementPtrInst *FirstInst =cast<GetElementPtrInst>(PN.getIncomingValue(0)); + + SmallVector<Value*, 16> FixedOperands(FirstInst->op_begin(), + FirstInst->op_end()); + // This is true if all GEP bases are allocas and if all indices into them are + // constants.
+ bool AllBasePointersAreAllocas = true; + + // We don't want to replace this phi if the replacement would require + // more than one phi, which leads to higher register pressure. This is + // especially bad when the PHIs are in the header of a loop. + bool NeededPhi = false; + + bool AllInBounds = true; + + // Scan to see if all operands are the same opcode, and all have one user. + for (Value *V : drop_begin(PN.incoming_values())) { + GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(V); + if (!GEP || !GEP->hasOneUser() || + GEP->getSourceElementType() != FirstInst->getSourceElementType() || + GEP->getNumOperands() != FirstInst->getNumOperands()) + return nullptr; + + AllInBounds &= GEP->isInBounds(); + + // Keep track of whether or not all GEPs are of alloca pointers. + if (AllBasePointersAreAllocas && + (!isa<AllocaInst>(GEP->getOperand(0)) || + !GEP->hasAllConstantIndices())) + AllBasePointersAreAllocas = false; + + // Compare the operand lists. + for (unsigned Op = 0, E = FirstInst->getNumOperands(); Op != E; ++Op) { + if (FirstInst->getOperand(Op) == GEP->getOperand(Op)) + continue; + + // Don't merge two GEPs when two operands differ (introducing phi nodes) + // if one of the PHIs has a constant for the index. The index may be + // substantially cheaper to compute for the constants, so making it a + // variable index could pessimize the path. This also handles the case + // for struct indices, which must always be constant. + if (isa<ConstantInt>(FirstInst->getOperand(Op)) || + isa<ConstantInt>(GEP->getOperand(Op))) + return nullptr; + + if (FirstInst->getOperand(Op)->getType() != + GEP->getOperand(Op)->getType()) + return nullptr; + + // If we already needed a PHI for an earlier operand, and another operand + // also requires a PHI, we'd be introducing more PHIs than we're + // eliminating, which increases register pressure on entry to the PHI's + // block. + if (NeededPhi) + return nullptr; + + FixedOperands[Op] = nullptr; // Needs a PHI.
+ NeededPhi = true; + } + } + + // If all of the base pointers of the PHI'd GEPs are from allocas, don't + // bother doing this transformation. At best, this will just save a bit of + // offset calculation, but all the predecessors will have to materialize the + // stack address into a register anyway. We'd actually rather *clone* the + // load up into the predecessors so that we have a load of a gep of an alloca, + // which can usually all be folded into the load. + if (AllBasePointersAreAllocas) + return nullptr; + + // Otherwise, this is safe to transform. Insert PHI nodes for each operand + // that is variable (one incoming value per predecessor of PN). + SmallVector<PHINode*, 16> OperandPhis(FixedOperands.size()); + + bool HasAnyPHIs = false; + for (unsigned I = 0, E = FixedOperands.size(); I != E; ++I) { + if (FixedOperands[I]) + continue; // operand doesn't need a phi. + Value *FirstOp = FirstInst->getOperand(I); + PHINode *NewPN = PHINode::Create( + FirstOp->getType(), PN.getNumIncomingValues(), FirstOp->getName() + ".pn"); + InsertNewInstBefore(NewPN, PN); + + NewPN->addIncoming(FirstOp, PN.getIncomingBlock(0)); + OperandPhis[I] = NewPN; + FixedOperands[I] = NewPN; + HasAnyPHIs = true; + } + + // Add all operands to the new PHIs. + if (HasAnyPHIs) { + for (auto Incoming : drop_begin(zip(PN.blocks(), PN.incoming_values()))) { + BasicBlock *InBB = std::get<0>(Incoming); + Value *InVal = std::get<1>(Incoming); + GetElementPtrInst *InGEP = cast<GetElementPtrInst>(InVal); + + for (unsigned Op = 0, E = OperandPhis.size(); Op != E; ++Op) + if (PHINode *OpPhi = OperandPhis[Op]) + OpPhi->addIncoming(InGEP->getOperand(Op), InBB); + } + } + + Value *Base = FixedOperands[0]; + GetElementPtrInst *NewGEP = + GetElementPtrInst::Create(FirstInst->getSourceElementType(), Base, + makeArrayRef(FixedOperands).slice(1)); + if (AllInBounds) NewGEP->setIsInBounds(); + PHIArgMergedDebugLoc(NewGEP, PN); + return NewGEP; +} + +/// Return true if we know that it is safe to sink the load out of the block + /// that defines it.
This means that it must be obvious the value of the load is +/// not changed from the point of the load to the end of the block it is in. +/// +/// Finally, it is safe, but not profitable, to sink a load targeting a +/// non-address-taken alloca. Doing so will cause us to not promote the alloca +/// to a register. +static bool isSafeAndProfitableToSinkLoad(LoadInst *L) { + BasicBlock::iterator BBI = L->getIterator(), E = L->getParent()->end(); + + for (++BBI; BBI != E; ++BBI) + if (BBI->mayWriteToMemory()) { + // Calls that only access inaccessible memory do not block sinking the + // load. + if (auto *CB = dyn_cast<CallBase>(BBI)) + if (CB->onlyAccessesInaccessibleMemory()) + continue; + return false; + } + + // Check for non-address taken alloca. If not address-taken already, it isn't + // profitable to do this xform. + if (AllocaInst *AI = dyn_cast<AllocaInst>(L->getOperand(0))) { + bool IsAddressTaken = false; + for (User *U : AI->users()) { + if (isa<LoadInst>(U)) continue; + if (StoreInst *SI = dyn_cast<StoreInst>(U)) { + // If storing TO the alloca, then the address isn't taken. + if (SI->getOperand(1) == AI) continue; + } + IsAddressTaken = true; + break; + } + + if (!IsAddressTaken && AI->isStaticAlloca()) + return false; + } + + // If this load is a load from a GEP with a constant offset from an alloca, + // then we don't want to sink it. In its present form, it will be + // load [constant stack offset]. Sinking it will cause us to have to + // materialize the stack addresses in each predecessor in a register only to + // do a shared load from register in the successor. 
+ if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(L->getOperand(0))) + if (AllocaInst *AI = dyn_cast<AllocaInst>(GEP->getOperand(0))) + if (AI->isStaticAlloca() && GEP->hasAllConstantIndices()) + return false; + + return true; +} + +Instruction *InstCombinerImpl::foldPHIArgLoadIntoPHI(PHINode &PN) { + LoadInst *FirstLI = cast<LoadInst>(PN.getIncomingValue(0)); + + // Can't forward swifterror through a phi. + if (FirstLI->getOperand(0)->isSwiftError()) + return nullptr; + + // FIXME: This is overconservative; this transform is allowed in some cases + // for atomic operations. + if (FirstLI->isAtomic()) + return nullptr; + + // When processing loads, we need to propagate two bits of information to the + // sunk load: whether it is volatile, and what its alignment is. + bool IsVolatile = FirstLI->isVolatile(); + Align LoadAlignment = FirstLI->getAlign(); + const unsigned LoadAddrSpace = FirstLI->getPointerAddressSpace(); + + // We can't sink the load if the loaded value could be modified between the + // load and the PHI. + if (FirstLI->getParent() != PN.getIncomingBlock(0) || + !isSafeAndProfitableToSinkLoad(FirstLI)) + return nullptr; + + // If the PHI is of volatile loads and the load block has multiple + // successors, sinking it would remove a load of the volatile value from + // the path through the other successor. + if (IsVolatile && + FirstLI->getParent()->getTerminator()->getNumSuccessors() != 1) + return nullptr; + + for (auto Incoming : drop_begin(zip(PN.blocks(), PN.incoming_values()))) { + BasicBlock *InBB = std::get<0>(Incoming); + Value *InVal = std::get<1>(Incoming); + LoadInst *LI = dyn_cast<LoadInst>(InVal); + if (!LI || !LI->hasOneUser() || LI->isAtomic()) + return nullptr; + + // Make sure all arguments are the same type of operation. + if (LI->isVolatile() != IsVolatile || + LI->getPointerAddressSpace() != LoadAddrSpace) + return nullptr; + + // Can't forward swifterror through a phi. 
+ if (LI->getOperand(0)->isSwiftError()) + return nullptr; + + // We can't sink the load if the loaded value could be modified between + // the load and the PHI. + if (LI->getParent() != InBB || !isSafeAndProfitableToSinkLoad(LI)) + return nullptr; + + LoadAlignment = std::min(LoadAlignment, LI->getAlign()); + + // If the PHI is of volatile loads and the load block has multiple + // successors, sinking it would remove a load of the volatile value from + // the path through the other successor. + if (IsVolatile && LI->getParent()->getTerminator()->getNumSuccessors() != 1) + return nullptr; + } + + // Okay, they are all the same operation. Create a new PHI node of the + // correct type, and PHI together all of the LHS's of the instructions. + PHINode *NewPN = PHINode::Create(FirstLI->getOperand(0)->getType(), + PN.getNumIncomingValues(), + PN.getName()+".in"); + + Value *InVal = FirstLI->getOperand(0); + NewPN->addIncoming(InVal, PN.getIncomingBlock(0)); + LoadInst *NewLI = + new LoadInst(FirstLI->getType(), NewPN, "", IsVolatile, LoadAlignment); + + unsigned KnownIDs[] = { + LLVMContext::MD_tbaa, + LLVMContext::MD_range, + LLVMContext::MD_invariant_load, + LLVMContext::MD_alias_scope, + LLVMContext::MD_noalias, + LLVMContext::MD_nonnull, + LLVMContext::MD_align, + LLVMContext::MD_dereferenceable, + LLVMContext::MD_dereferenceable_or_null, + LLVMContext::MD_access_group, + }; + + for (unsigned ID : KnownIDs) + NewLI->setMetadata(ID, FirstLI->getMetadata(ID)); + + // Add all operands to the new PHI and combine TBAA metadata. + for (auto Incoming : drop_begin(zip(PN.blocks(), PN.incoming_values()))) { + BasicBlock *BB = std::get<0>(Incoming); + Value *V = std::get<1>(Incoming); + LoadInst *LI = cast<LoadInst>(V); + combineMetadata(NewLI, LI, KnownIDs, true); + Value *NewInVal = LI->getOperand(0); + if (NewInVal != InVal) + InVal = nullptr; + NewPN->addIncoming(NewInVal, BB); + } + + if (InVal) { + // The new PHI unions all of the same values together. 
This is really + // common, so we handle it intelligently here for compile-time speed. + NewLI->setOperand(0, InVal); + delete NewPN; + } else { + InsertNewInstBefore(NewPN, PN); + } + + // If this was a volatile load that we are merging, make sure to loop through + // and mark all the input loads as non-volatile. If we don't do this, we will + // insert a new volatile load and the old ones will not be deletable. + if (IsVolatile) + for (Value *IncValue : PN.incoming_values()) + cast<LoadInst>(IncValue)->setVolatile(false); + + PHIArgMergedDebugLoc(NewLI, PN); + return NewLI; +} + +/// TODO: This function could handle other cast types, but then it might +/// require special-casing a cast from the 'i1' type. See the comment in +/// FoldPHIArgOpIntoPHI() about pessimizing illegal integer types. +Instruction *InstCombinerImpl::foldPHIArgZextsIntoPHI(PHINode &Phi) { + // We cannot create a new instruction after the PHI if the terminator is an + // EHPad because there is no valid insertion point. + if (Instruction *TI = Phi.getParent()->getTerminator()) + if (TI->isEHPad()) + return nullptr; + + // Early exit for the common case of a phi with two operands. These are + // handled elsewhere. See the comment below where we check the count of zexts + // and constants for more details. + unsigned NumIncomingValues = Phi.getNumIncomingValues(); + if (NumIncomingValues < 3) + return nullptr; + + // Find the narrower type specified by the first zext. + Type *NarrowType = nullptr; + for (Value *V : Phi.incoming_values()) { + if (auto *Zext = dyn_cast<ZExtInst>(V)) { + NarrowType = Zext->getSrcTy(); + break; + } + } + if (!NarrowType) + return nullptr; + + // Walk the phi operands checking that we only have zexts or constants that + // we can shrink for free. Store the new operands for the new phi. 
+ SmallVector<Value *, 4> NewIncoming; + unsigned NumZexts = 0; + unsigned NumConsts = 0; + for (Value *V : Phi.incoming_values()) { + if (auto *Zext = dyn_cast<ZExtInst>(V)) { + // All zexts must be identical and have one user. + if (Zext->getSrcTy() != NarrowType || !Zext->hasOneUser()) + return nullptr; + NewIncoming.push_back(Zext->getOperand(0)); + NumZexts++; + } else if (auto *C = dyn_cast<Constant>(V)) { + // Make sure that constants can fit in the new type. + Constant *Trunc = ConstantExpr::getTrunc(C, NarrowType); + if (ConstantExpr::getZExt(Trunc, C->getType()) != C) + return nullptr; + NewIncoming.push_back(Trunc); + NumConsts++; + } else { + // If it's not a cast or a constant, bail out. + return nullptr; + } + } + + // The more common cases of a phi with no constant operands or just one + // variable operand are handled by FoldPHIArgOpIntoPHI() and foldOpIntoPhi() + // respectively. foldOpIntoPhi() wants to do the opposite transform that is + // performed here. It tries to replicate a cast in the phi operand's basic + // block to expose other folding opportunities. Thus, InstCombine will + // infinite loop without this check. + if (NumConsts == 0 || NumZexts < 2) + return nullptr; + + // All incoming values are zexts or constants that are safe to truncate. + // Create a new phi node of the narrow type, phi together all of the new + // operands, and zext the result back to the original type. + PHINode *NewPhi = PHINode::Create(NarrowType, NumIncomingValues, + Phi.getName() + ".shrunk"); + for (unsigned I = 0; I != NumIncomingValues; ++I) + NewPhi->addIncoming(NewIncoming[I], Phi.getIncomingBlock(I)); + + InsertNewInstBefore(NewPhi, Phi); + return CastInst::CreateZExtOrBitCast(NewPhi, Phi.getType()); +} + +/// If all operands to a PHI node are the same "unary" operator and they all are +/// only used by the PHI, PHI together their inputs, and do the operation once, +/// to the result of the PHI. 
+Instruction *InstCombinerImpl::foldPHIArgOpIntoPHI(PHINode &PN) { + // We cannot create a new instruction after the PHI if the terminator is an + // EHPad because there is no valid insertion point. + if (Instruction *TI = PN.getParent()->getTerminator()) + if (TI->isEHPad()) + return nullptr; + + Instruction *FirstInst = cast<Instruction>(PN.getIncomingValue(0)); + + if (isa<GetElementPtrInst>(FirstInst)) + return foldPHIArgGEPIntoPHI(PN); + if (isa<LoadInst>(FirstInst)) + return foldPHIArgLoadIntoPHI(PN); + if (isa<InsertValueInst>(FirstInst)) + return foldPHIArgInsertValueInstructionIntoPHI(PN); + if (isa<ExtractValueInst>(FirstInst)) + return foldPHIArgExtractValueInstructionIntoPHI(PN); + + // Scan the instruction, looking for input operations that can be folded away. + // If all input operands to the phi are the same instruction (e.g. a cast from + // the same type or "+42") we can pull the operation through the PHI, reducing + // code size and simplifying code. + Constant *ConstantOp = nullptr; + Type *CastSrcTy = nullptr; + + if (isa<CastInst>(FirstInst)) { + CastSrcTy = FirstInst->getOperand(0)->getType(); + + // Be careful about transforming integer PHIs. We don't want to pessimize + // the code by turning an i32 into an i1293. + if (PN.getType()->isIntegerTy() && CastSrcTy->isIntegerTy()) { + if (!shouldChangeType(PN.getType(), CastSrcTy)) + return nullptr; + } + } else if (isa<BinaryOperator>(FirstInst) || isa<CmpInst>(FirstInst)) { + // Can fold binop, compare or shift here if the RHS is a constant, + // otherwise call FoldPHIArgBinOpIntoPHI. + ConstantOp = dyn_cast<Constant>(FirstInst->getOperand(1)); + if (!ConstantOp) + return foldPHIArgBinOpIntoPHI(PN); + } else { + return nullptr; // Cannot fold this operation. + } + + // Check to see if all arguments are the same operation. 
+ for (Value *V : drop_begin(PN.incoming_values())) { + Instruction *I = dyn_cast<Instruction>(V); + if (!I || !I->hasOneUser() || !I->isSameOperationAs(FirstInst)) + return nullptr; + if (CastSrcTy) { + if (I->getOperand(0)->getType() != CastSrcTy) + return nullptr; // Cast operation must match. + } else if (I->getOperand(1) != ConstantOp) { + return nullptr; + } + } + + // Okay, they are all the same operation. Create a new PHI node of the + // correct type, and PHI together all of the LHS's of the instructions. + PHINode *NewPN = PHINode::Create(FirstInst->getOperand(0)->getType(), + PN.getNumIncomingValues(), + PN.getName()+".in"); + + Value *InVal = FirstInst->getOperand(0); + NewPN->addIncoming(InVal, PN.getIncomingBlock(0)); + + // Add all operands to the new PHI. + for (auto Incoming : drop_begin(zip(PN.blocks(), PN.incoming_values()))) { + BasicBlock *BB = std::get<0>(Incoming); + Value *V = std::get<1>(Incoming); + Value *NewInVal = cast<Instruction>(V)->getOperand(0); + if (NewInVal != InVal) + InVal = nullptr; + NewPN->addIncoming(NewInVal, BB); + } + + Value *PhiVal; + if (InVal) { + // The new PHI unions all of the same values together. This is really + // common, so we handle it intelligently here for compile-time speed. + PhiVal = InVal; + delete NewPN; + } else { + InsertNewInstBefore(NewPN, PN); + PhiVal = NewPN; + } + + // Insert and return the new operation. 
+ if (CastInst *FirstCI = dyn_cast<CastInst>(FirstInst)) { + CastInst *NewCI = CastInst::Create(FirstCI->getOpcode(), PhiVal, + PN.getType()); + PHIArgMergedDebugLoc(NewCI, PN); + return NewCI; + } + + if (BinaryOperator *BinOp = dyn_cast<BinaryOperator>(FirstInst)) { + BinOp = BinaryOperator::Create(BinOp->getOpcode(), PhiVal, ConstantOp); + BinOp->copyIRFlags(PN.getIncomingValue(0)); + + for (Value *V : drop_begin(PN.incoming_values())) + BinOp->andIRFlags(V); + + PHIArgMergedDebugLoc(BinOp, PN); + return BinOp; + } + + CmpInst *CIOp = cast<CmpInst>(FirstInst); + CmpInst *NewCI = CmpInst::Create(CIOp->getOpcode(), CIOp->getPredicate(), + PhiVal, ConstantOp); + PHIArgMergedDebugLoc(NewCI, PN); + return NewCI; +} + +/// Return true if this PHI node is only used by a PHI node cycle that is dead. +static bool isDeadPHICycle(PHINode *PN, + SmallPtrSetImpl<PHINode *> &PotentiallyDeadPHIs) { + if (PN->use_empty()) return true; + if (!PN->hasOneUse()) return false; + + // Remember this node, and if we find the cycle, return. + if (!PotentiallyDeadPHIs.insert(PN).second) + return true; + + // Don't scan crazily complex things. + if (PotentiallyDeadPHIs.size() == 16) + return false; + + if (PHINode *PU = dyn_cast<PHINode>(PN->user_back())) + return isDeadPHICycle(PU, PotentiallyDeadPHIs); + + return false; +} + +/// Return true if this phi node is always equal to NonPhiInVal. +/// This happens with mutually cyclic phi nodes like: +/// z = some value; x = phi (y, z); y = phi (x, z) +static bool PHIsEqualValue(PHINode *PN, Value *NonPhiInVal, + SmallPtrSetImpl<PHINode*> &ValueEqualPHIs) { + // See if we already saw this PHI node. + if (!ValueEqualPHIs.insert(PN).second) + return true; + + // Don't scan crazily complex things. + if (ValueEqualPHIs.size() == 16) + return false; + + // Scan the operands to see if they are either phi nodes or are equal to + // the value. 
+ for (Value *Op : PN->incoming_values()) { + if (PHINode *OpPN = dyn_cast<PHINode>(Op)) { + if (!PHIsEqualValue(OpPN, NonPhiInVal, ValueEqualPHIs)) + return false; + } else if (Op != NonPhiInVal) + return false; + } + + return true; +} + +/// Return an existing non-zero constant if this phi node has one, otherwise +/// return constant 1. +static ConstantInt *getAnyNonZeroConstInt(PHINode &PN) { + assert(isa<IntegerType>(PN.getType()) && "Expect only integer type phi"); + for (Value *V : PN.operands()) + if (auto *ConstVA = dyn_cast<ConstantInt>(V)) + if (!ConstVA->isZero()) + return ConstVA; + return ConstantInt::get(cast<IntegerType>(PN.getType()), 1); +} + +namespace { +struct PHIUsageRecord { + unsigned PHIId; // The ID # of the PHI (something determinstic to sort on) + unsigned Shift; // The amount shifted. + Instruction *Inst; // The trunc instruction. + + PHIUsageRecord(unsigned Pn, unsigned Sh, Instruction *User) + : PHIId(Pn), Shift(Sh), Inst(User) {} + + bool operator<(const PHIUsageRecord &RHS) const { + if (PHIId < RHS.PHIId) return true; + if (PHIId > RHS.PHIId) return false; + if (Shift < RHS.Shift) return true; + if (Shift > RHS.Shift) return false; + return Inst->getType()->getPrimitiveSizeInBits() < + RHS.Inst->getType()->getPrimitiveSizeInBits(); + } +}; + +struct LoweredPHIRecord { + PHINode *PN; // The PHI that was lowered. + unsigned Shift; // The amount shifted. + unsigned Width; // The width extracted. + + LoweredPHIRecord(PHINode *Phi, unsigned Sh, Type *Ty) + : PN(Phi), Shift(Sh), Width(Ty->getPrimitiveSizeInBits()) {} + + // Ctor form used by DenseMap. 
+ LoweredPHIRecord(PHINode *Phi, unsigned Sh) : PN(Phi), Shift(Sh), Width(0) {} +}; +} // namespace + +namespace llvm { + template<> + struct DenseMapInfo<LoweredPHIRecord> { + static inline LoweredPHIRecord getEmptyKey() { + return LoweredPHIRecord(nullptr, 0); + } + static inline LoweredPHIRecord getTombstoneKey() { + return LoweredPHIRecord(nullptr, 1); + } + static unsigned getHashValue(const LoweredPHIRecord &Val) { + return DenseMapInfo<PHINode*>::getHashValue(Val.PN) ^ (Val.Shift>>3) ^ + (Val.Width>>3); + } + static bool isEqual(const LoweredPHIRecord &LHS, + const LoweredPHIRecord &RHS) { + return LHS.PN == RHS.PN && LHS.Shift == RHS.Shift && + LHS.Width == RHS.Width; + } + }; +} // namespace llvm + + +/// This is an integer PHI and we know that it has an illegal type: see if it is +/// only used by trunc or trunc(lshr) operations. If so, we split the PHI into +/// the various pieces being extracted. This sort of thing is introduced when +/// SROA promotes an aggregate to large integer values. +/// +/// TODO: The user of the trunc may be an bitcast to float/double/vector or an +/// inttoptr. We should produce new PHIs in the right type. +/// +Instruction *InstCombinerImpl::SliceUpIllegalIntegerPHI(PHINode &FirstPhi) { + // PHIUsers - Keep track of all of the truncated values extracted from a set + // of PHIs, along with their offset. These are the things we want to rewrite. + SmallVector<PHIUsageRecord, 16> PHIUsers; + + // PHIs are often mutually cyclic, so we keep track of a whole set of PHI + // nodes which are extracted from. PHIsToSlice is a set we use to avoid + // revisiting PHIs, PHIsInspected is a ordered list of PHIs that we need to + // check the uses of (to ensure they are all extracts). 
+ SmallVector<PHINode*, 8> PHIsToSlice; + SmallPtrSet<PHINode*, 8> PHIsInspected; + + PHIsToSlice.push_back(&FirstPhi); + PHIsInspected.insert(&FirstPhi); + + for (unsigned PHIId = 0; PHIId != PHIsToSlice.size(); ++PHIId) { + PHINode *PN = PHIsToSlice[PHIId]; + + // Scan the input list of the PHI. If any input is an invoke, and if the + // input is defined in the predecessor, then we won't be split the critical + // edge which is required to insert a truncate. Because of this, we have to + // bail out. + for (auto Incoming : zip(PN->blocks(), PN->incoming_values())) { + BasicBlock *BB = std::get<0>(Incoming); + Value *V = std::get<1>(Incoming); + InvokeInst *II = dyn_cast<InvokeInst>(V); + if (!II) + continue; + if (II->getParent() != BB) + continue; + + // If we have a phi, and if it's directly in the predecessor, then we have + // a critical edge where we need to put the truncate. Since we can't + // split the edge in instcombine, we have to bail out. + return nullptr; + } + + // If the incoming value is a PHI node before a catchswitch, we cannot + // extract the value within that BB because we cannot insert any non-PHI + // instructions in the BB. + for (auto *Pred : PN->blocks()) + if (Pred->getFirstInsertionPt() == Pred->end()) + return nullptr; + + for (User *U : PN->users()) { + Instruction *UserI = cast<Instruction>(U); + + // If the user is a PHI, inspect its uses recursively. + if (PHINode *UserPN = dyn_cast<PHINode>(UserI)) { + if (PHIsInspected.insert(UserPN).second) + PHIsToSlice.push_back(UserPN); + continue; + } + + // Truncates are always ok. + if (isa<TruncInst>(UserI)) { + PHIUsers.push_back(PHIUsageRecord(PHIId, 0, UserI)); + continue; + } + + // Otherwise it must be a lshr which can only be used by one trunc. + if (UserI->getOpcode() != Instruction::LShr || + !UserI->hasOneUse() || !isa<TruncInst>(UserI->user_back()) || + !isa<ConstantInt>(UserI->getOperand(1))) + return nullptr; + + // Bail on out of range shifts. 
+ unsigned SizeInBits = UserI->getType()->getScalarSizeInBits(); + if (cast<ConstantInt>(UserI->getOperand(1))->getValue().uge(SizeInBits)) + return nullptr; + + unsigned Shift = cast<ConstantInt>(UserI->getOperand(1))->getZExtValue(); + PHIUsers.push_back(PHIUsageRecord(PHIId, Shift, UserI->user_back())); + } + } + + // If we have no users, they must be all self uses, just nuke the PHI. + if (PHIUsers.empty()) + return replaceInstUsesWith(FirstPhi, PoisonValue::get(FirstPhi.getType())); + + // If this phi node is transformable, create new PHIs for all the pieces + // extracted out of it. First, sort the users by their offset and size. + array_pod_sort(PHIUsers.begin(), PHIUsers.end()); + + LLVM_DEBUG(dbgs() << "SLICING UP PHI: " << FirstPhi << '\n'; + for (unsigned I = 1; I != PHIsToSlice.size(); ++I) dbgs() + << "AND USER PHI #" << I << ": " << *PHIsToSlice[I] << '\n'); + + // PredValues - This is a temporary used when rewriting PHI nodes. It is + // hoisted out here to avoid construction/destruction thrashing. + DenseMap<BasicBlock*, Value*> PredValues; + + // ExtractedVals - Each new PHI we introduce is saved here so we don't + // introduce redundant PHIs. + DenseMap<LoweredPHIRecord, PHINode*> ExtractedVals; + + for (unsigned UserI = 0, UserE = PHIUsers.size(); UserI != UserE; ++UserI) { + unsigned PHIId = PHIUsers[UserI].PHIId; + PHINode *PN = PHIsToSlice[PHIId]; + unsigned Offset = PHIUsers[UserI].Shift; + Type *Ty = PHIUsers[UserI].Inst->getType(); + + PHINode *EltPHI; + + // If we've already lowered a user like this, reuse the previously lowered + // value. + if ((EltPHI = ExtractedVals[LoweredPHIRecord(PN, Offset, Ty)]) == nullptr) { + + // Otherwise, Create the new PHI node for this user. 
+ EltPHI = PHINode::Create(Ty, PN->getNumIncomingValues(), + PN->getName()+".off"+Twine(Offset), PN); + assert(EltPHI->getType() != PN->getType() && + "Truncate didn't shrink phi?"); + + for (auto Incoming : zip(PN->blocks(), PN->incoming_values())) { + BasicBlock *Pred = std::get<0>(Incoming); + Value *InVal = std::get<1>(Incoming); + Value *&PredVal = PredValues[Pred]; + + // If we already have a value for this predecessor, reuse it. + if (PredVal) { + EltPHI->addIncoming(PredVal, Pred); + continue; + } + + // Handle the PHI self-reuse case. + if (InVal == PN) { + PredVal = EltPHI; + EltPHI->addIncoming(PredVal, Pred); + continue; + } + + if (PHINode *InPHI = dyn_cast<PHINode>(PN)) { + // If the incoming value was a PHI, and if it was one of the PHIs we + // already rewrote it, just use the lowered value. + if (Value *Res = ExtractedVals[LoweredPHIRecord(InPHI, Offset, Ty)]) { + PredVal = Res; + EltPHI->addIncoming(PredVal, Pred); + continue; + } + } + + // Otherwise, do an extract in the predecessor. + Builder.SetInsertPoint(Pred->getTerminator()); + Value *Res = InVal; + if (Offset) + Res = Builder.CreateLShr( + Res, ConstantInt::get(InVal->getType(), Offset), "extract"); + Res = Builder.CreateTrunc(Res, Ty, "extract.t"); + PredVal = Res; + EltPHI->addIncoming(Res, Pred); + + // If the incoming value was a PHI, and if it was one of the PHIs we are + // rewriting, we will ultimately delete the code we inserted. This + // means we need to revisit that PHI to make sure we extract out the + // needed piece. 
+ if (PHINode *OldInVal = dyn_cast<PHINode>(InVal)) + if (PHIsInspected.count(OldInVal)) { + unsigned RefPHIId = + find(PHIsToSlice, OldInVal) - PHIsToSlice.begin(); + PHIUsers.push_back( + PHIUsageRecord(RefPHIId, Offset, cast<Instruction>(Res))); + ++UserE; + } + } + PredValues.clear(); + + LLVM_DEBUG(dbgs() << " Made element PHI for offset " << Offset << ": " + << *EltPHI << '\n'); + ExtractedVals[LoweredPHIRecord(PN, Offset, Ty)] = EltPHI; + } + + // Replace the use of this piece with the PHI node. + replaceInstUsesWith(*PHIUsers[UserI].Inst, EltPHI); + } + + // Replace all the remaining uses of the PHI nodes (self uses and the lshrs) + // with poison. + Value *Poison = PoisonValue::get(FirstPhi.getType()); + for (PHINode *PHI : drop_begin(PHIsToSlice)) + replaceInstUsesWith(*PHI, Poison); + return replaceInstUsesWith(FirstPhi, Poison); +} + +static Value *simplifyUsingControlFlow(InstCombiner &Self, PHINode &PN, + const DominatorTree &DT) { + // Simplify the following patterns: + // if (cond) + // / \ + // ... ... + // \ / + // phi [true] [false] + // and + // switch (cond) + // case v1: / \ case v2: + // ... ... + // \ / + // phi [v1] [v2] + // Make sure all inputs are constants. + if (!all_of(PN.operands(), [](Value *V) { return isa<ConstantInt>(V); })) + return nullptr; + + BasicBlock *BB = PN.getParent(); + // Do not bother with unreachable instructions. + if (!DT.isReachableFromEntry(BB)) + return nullptr; + + // Determine which value the condition of the idom has for which successor. 
+ LLVMContext &Context = PN.getContext(); + auto *IDom = DT.getNode(BB)->getIDom()->getBlock(); + Value *Cond; + SmallDenseMap<ConstantInt *, BasicBlock *, 8> SuccForValue; + SmallDenseMap<BasicBlock *, unsigned, 8> SuccCount; + auto AddSucc = [&](ConstantInt *C, BasicBlock *Succ) { + SuccForValue[C] = Succ; + ++SuccCount[Succ]; + }; + if (auto *BI = dyn_cast<BranchInst>(IDom->getTerminator())) { + if (BI->isUnconditional()) + return nullptr; + + Cond = BI->getCondition(); + AddSucc(ConstantInt::getTrue(Context), BI->getSuccessor(0)); + AddSucc(ConstantInt::getFalse(Context), BI->getSuccessor(1)); + } else if (auto *SI = dyn_cast<SwitchInst>(IDom->getTerminator())) { + Cond = SI->getCondition(); + ++SuccCount[SI->getDefaultDest()]; + for (auto Case : SI->cases()) + AddSucc(Case.getCaseValue(), Case.getCaseSuccessor()); + } else { + return nullptr; + } + + if (Cond->getType() != PN.getType()) + return nullptr; + + // Check that edges outgoing from the idom's terminators dominate respective + // inputs of the Phi. + Optional<bool> Invert; + for (auto Pair : zip(PN.incoming_values(), PN.blocks())) { + auto *Input = cast<ConstantInt>(std::get<0>(Pair)); + BasicBlock *Pred = std::get<1>(Pair); + auto IsCorrectInput = [&](ConstantInt *Input) { + // The input needs to be dominated by the corresponding edge of the idom. + // This edge cannot be a multi-edge, as that would imply that multiple + // different condition values follow the same edge. + auto It = SuccForValue.find(Input); + return It != SuccForValue.end() && SuccCount[It->second] == 1 && + DT.dominates(BasicBlockEdge(IDom, It->second), + BasicBlockEdge(Pred, BB)); + }; + + // Depending on the constant, the condition may need to be inverted. + bool NeedsInvert; + if (IsCorrectInput(Input)) + NeedsInvert = false; + else if (IsCorrectInput(cast<ConstantInt>(ConstantExpr::getNot(Input)))) + NeedsInvert = true; + else + return nullptr; + + // Make sure the inversion requirement is always the same. 
+ if (Invert && *Invert != NeedsInvert) + return nullptr; + + Invert = NeedsInvert; + } + + if (!*Invert) + return Cond; + + // This Phi is actually opposite to branching condition of IDom. We invert + // the condition that will potentially open up some opportunities for + // sinking. + auto InsertPt = BB->getFirstInsertionPt(); + if (InsertPt != BB->end()) { + Self.Builder.SetInsertPoint(&*InsertPt); + return Self.Builder.CreateNot(Cond); + } + + return nullptr; +} + +// PHINode simplification +// +Instruction *InstCombinerImpl::visitPHINode(PHINode &PN) { + if (Value *V = simplifyInstruction(&PN, SQ.getWithInstruction(&PN))) + return replaceInstUsesWith(PN, V); + + if (Instruction *Result = foldPHIArgZextsIntoPHI(PN)) + return Result; + + if (Instruction *Result = foldPHIArgIntToPtrToPHI(PN)) + return Result; + + // If all PHI operands are the same operation, pull them through the PHI, + // reducing code size. + if (isa<Instruction>(PN.getIncomingValue(0)) && + isa<Instruction>(PN.getIncomingValue(1)) && + cast<Instruction>(PN.getIncomingValue(0))->getOpcode() == + cast<Instruction>(PN.getIncomingValue(1))->getOpcode() && + PN.getIncomingValue(0)->hasOneUser()) + if (Instruction *Result = foldPHIArgOpIntoPHI(PN)) + return Result; + + // If the incoming values are pointer casts of the same original value, + // replace the phi with a single cast iff we can insert a non-PHI instruction. + if (PN.getType()->isPointerTy() && + PN.getParent()->getFirstInsertionPt() != PN.getParent()->end()) { + Value *IV0 = PN.getIncomingValue(0); + Value *IV0Stripped = IV0->stripPointerCasts(); + // Set to keep track of values known to be equal to IV0Stripped after + // stripping pointer casts. 
+ SmallPtrSet<Value *, 4> CheckedIVs; + CheckedIVs.insert(IV0); + if (IV0 != IV0Stripped && + all_of(PN.incoming_values(), [&CheckedIVs, IV0Stripped](Value *IV) { + return !CheckedIVs.insert(IV).second || + IV0Stripped == IV->stripPointerCasts(); + })) { + return CastInst::CreatePointerCast(IV0Stripped, PN.getType()); + } + } + + // If this is a trivial cycle in the PHI node graph, remove it. Basically, if + // this PHI only has a single use (a PHI), and if that PHI only has one use (a + // PHI)... break the cycle. + if (PN.hasOneUse()) { + if (Instruction *Result = foldIntegerTypedPHI(PN)) + return Result; + + Instruction *PHIUser = cast<Instruction>(PN.user_back()); + if (PHINode *PU = dyn_cast<PHINode>(PHIUser)) { + SmallPtrSet<PHINode*, 16> PotentiallyDeadPHIs; + PotentiallyDeadPHIs.insert(&PN); + if (isDeadPHICycle(PU, PotentiallyDeadPHIs)) + return replaceInstUsesWith(PN, PoisonValue::get(PN.getType())); + } + + // If this phi has a single use, and if that use just computes a value for + // the next iteration of a loop, delete the phi. This occurs with unused + // induction variables, e.g. "for (int j = 0; ; ++j);". Detecting this + // common case here is good because the only other things that catch this + // are induction variable analysis (sometimes) and ADCE, which is only run + // late. + if (PHIUser->hasOneUse() && + (isa<BinaryOperator>(PHIUser) || isa<GetElementPtrInst>(PHIUser)) && + PHIUser->user_back() == &PN) { + return replaceInstUsesWith(PN, PoisonValue::get(PN.getType())); + } + // When a PHI is used only to be compared with zero, it is safe to replace + // an incoming value proved as known nonzero with any non-zero constant. + // For example, in the code below, the incoming value %v can be replaced + // with any non-zero constant based on the fact that the PHI is only used to + // be compared with zero and %v is a known non-zero value: + // %v = select %cond, 1, 2 + // %p = phi [%v, BB] ... 
+ // icmp eq, %p, 0 + auto *CmpInst = dyn_cast<ICmpInst>(PHIUser); + // FIXME: To be simple, handle only integer type for now. + if (CmpInst && isa<IntegerType>(PN.getType()) && CmpInst->isEquality() && + match(CmpInst->getOperand(1), m_Zero())) { + ConstantInt *NonZeroConst = nullptr; + bool MadeChange = false; + for (unsigned I = 0, E = PN.getNumIncomingValues(); I != E; ++I) { + Instruction *CtxI = PN.getIncomingBlock(I)->getTerminator(); + Value *VA = PN.getIncomingValue(I); + if (isKnownNonZero(VA, DL, 0, &AC, CtxI, &DT)) { + if (!NonZeroConst) + NonZeroConst = getAnyNonZeroConstInt(PN); + + if (NonZeroConst != VA) { + replaceOperand(PN, I, NonZeroConst); + MadeChange = true; + } + } + } + if (MadeChange) + return &PN; + } + } + + // We sometimes end up with phi cycles that non-obviously end up being the + // same value, for example: + // z = some value; x = phi (y, z); y = phi (x, z) + // where the phi nodes don't necessarily need to be in the same block. Do a + // quick check to see if the PHI node only contains a single non-phi value, if + // so, scan to see if the phi cycle is actually equal to that value. + { + unsigned InValNo = 0, NumIncomingVals = PN.getNumIncomingValues(); + // Scan for the first non-phi operand. + while (InValNo != NumIncomingVals && + isa<PHINode>(PN.getIncomingValue(InValNo))) + ++InValNo; + + if (InValNo != NumIncomingVals) { + Value *NonPhiInVal = PN.getIncomingValue(InValNo); + + // Scan the rest of the operands to see if there are any conflicts, if so + // there is no need to recursively scan other phis. + for (++InValNo; InValNo != NumIncomingVals; ++InValNo) { + Value *OpVal = PN.getIncomingValue(InValNo); + if (OpVal != NonPhiInVal && !isa<PHINode>(OpVal)) + break; + } + + // If we scanned over all operands, then we have one unique value plus + // phi values. Scan PHI nodes to see if they all merge in each other or + // the value. 
+ if (InValNo == NumIncomingVals) { + SmallPtrSet<PHINode*, 16> ValueEqualPHIs; + if (PHIsEqualValue(&PN, NonPhiInVal, ValueEqualPHIs)) + return replaceInstUsesWith(PN, NonPhiInVal); + } + } + } + + // If there are multiple PHIs, sort their operands so that they all list + // the blocks in the same order. This will help identical PHIs be eliminated + // by other passes. Other passes shouldn't depend on this for correctness + // however. + PHINode *FirstPN = cast<PHINode>(PN.getParent()->begin()); + if (&PN != FirstPN) + for (unsigned I = 0, E = FirstPN->getNumIncomingValues(); I != E; ++I) { + BasicBlock *BBA = PN.getIncomingBlock(I); + BasicBlock *BBB = FirstPN->getIncomingBlock(I); + if (BBA != BBB) { + Value *VA = PN.getIncomingValue(I); + unsigned J = PN.getBasicBlockIndex(BBB); + Value *VB = PN.getIncomingValue(J); + PN.setIncomingBlock(I, BBB); + PN.setIncomingValue(I, VB); + PN.setIncomingBlock(J, BBA); + PN.setIncomingValue(J, VA); + // NOTE: Instcombine normally would want us to "return &PN" if we + // modified any of the operands of an instruction. However, since we + // aren't adding or removing uses (just rearranging them) we don't do + // this in this case. + } + } + + // Is there an identical PHI node in this basic block? + for (PHINode &IdenticalPN : PN.getParent()->phis()) { + // Ignore the PHI node itself. + if (&IdenticalPN == &PN) + continue; + // Note that even though we've just canonicalized this PHI, due to the + // worklist visitation order, there are no guarantess that *every* PHI + // has been canonicalized, so we can't just compare operands ranges. + if (!PN.isIdenticalToWhenDefined(&IdenticalPN)) + continue; + // Just use that PHI instead then. + ++NumPHICSEs; + return replaceInstUsesWith(PN, &IdenticalPN); + } + + // If this is an integer PHI and we know that it has an illegal type, see if + // it is only used by trunc or trunc(lshr) operations. If so, we split the + // PHI into the various pieces being extracted. 
This sort of thing is + // introduced when SROA promotes an aggregate to a single large integer type. + if (PN.getType()->isIntegerTy() && + !DL.isLegalInteger(PN.getType()->getPrimitiveSizeInBits())) + if (Instruction *Res = SliceUpIllegalIntegerPHI(PN)) + return Res; + + // Ultimately, try to replace this Phi with a dominating condition. + if (auto *V = simplifyUsingControlFlow(*this, PN, DT)) + return replaceInstUsesWith(PN, V); + + return nullptr; +} diff --git a/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp new file mode 100644 index 000000000000..ad96a5f475f1 --- /dev/null +++ b/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp @@ -0,0 +1,3202 @@ +//===- InstCombineSelect.cpp ----------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the visitSelect function. 
+// +//===----------------------------------------------------------------------===// + +#include "InstCombineInternal.h" +#include "llvm/ADT/APInt.h" +#include "llvm/ADT/Optional.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Analysis/AssumptionCache.h" +#include "llvm/Analysis/CmpInstAnalysis.h" +#include "llvm/Analysis/InstructionSimplify.h" +#include "llvm/Analysis/OverflowInstAnalysis.h" +#include "llvm/Analysis/ValueTracking.h" +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/Constant.h" +#include "llvm/IR/ConstantRange.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/DerivedTypes.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/InstrTypes.h" +#include "llvm/IR/Instruction.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/Intrinsics.h" +#include "llvm/IR/Operator.h" +#include "llvm/IR/PatternMatch.h" +#include "llvm/IR/Type.h" +#include "llvm/IR/User.h" +#include "llvm/IR/Value.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/KnownBits.h" +#include "llvm/Transforms/InstCombine/InstCombiner.h" +#include <cassert> +#include <utility> + +#define DEBUG_TYPE "instcombine" +#include "llvm/Transforms/Utils/InstructionWorklist.h" + +using namespace llvm; +using namespace PatternMatch; + + +/// Replace a select operand based on an equality comparison with the identity +/// constant of a binop. +static Instruction *foldSelectBinOpIdentity(SelectInst &Sel, + const TargetLibraryInfo &TLI, + InstCombinerImpl &IC) { + // The select condition must be an equality compare with a constant operand. 
+ Value *X; + Constant *C; + CmpInst::Predicate Pred; + if (!match(Sel.getCondition(), m_Cmp(Pred, m_Value(X), m_Constant(C)))) + return nullptr; + + bool IsEq; + if (ICmpInst::isEquality(Pred)) + IsEq = Pred == ICmpInst::ICMP_EQ; + else if (Pred == FCmpInst::FCMP_OEQ) + IsEq = true; + else if (Pred == FCmpInst::FCMP_UNE) + IsEq = false; + else + return nullptr; + + // A select operand must be a binop. + BinaryOperator *BO; + if (!match(Sel.getOperand(IsEq ? 1 : 2), m_BinOp(BO))) + return nullptr; + + // The compare constant must be the identity constant for that binop. + // If this a floating-point compare with 0.0, any zero constant will do. + Type *Ty = BO->getType(); + Constant *IdC = ConstantExpr::getBinOpIdentity(BO->getOpcode(), Ty, true); + if (IdC != C) { + if (!IdC || !CmpInst::isFPPredicate(Pred)) + return nullptr; + if (!match(IdC, m_AnyZeroFP()) || !match(C, m_AnyZeroFP())) + return nullptr; + } + + // Last, match the compare variable operand with a binop operand. + Value *Y; + if (!BO->isCommutative() && !match(BO, m_BinOp(m_Value(Y), m_Specific(X)))) + return nullptr; + if (!match(BO, m_c_BinOp(m_Value(Y), m_Specific(X)))) + return nullptr; + + // +0.0 compares equal to -0.0, and so it does not behave as required for this + // transform. Bail out if we can not exclude that possibility. + if (isa<FPMathOperator>(BO)) + if (!BO->hasNoSignedZeros() && !CannotBeNegativeZero(Y, &TLI)) + return nullptr; + + // BO = binop Y, X + // S = { select (cmp eq X, C), BO, ? } or { select (cmp ne X, C), ?, BO } + // => + // S = { select (cmp eq X, C), Y, ? } or { select (cmp ne X, C), ?, Y } + return IC.replaceOperand(Sel, IsEq ? 1 : 2, Y); +} + +/// This folds: +/// select (icmp eq (and X, C1)), TC, FC +/// iff C1 is a power 2 and the difference between TC and FC is a power-of-2. 
+/// To something like: +/// (shr (and (X, C1)), (log2(C1) - log2(TC-FC))) + FC +/// Or: +/// (shl (and (X, C1)), (log2(TC-FC) - log2(C1))) + FC +/// With some variations depending if FC is larger than TC, or the shift +/// isn't needed, or the bit widths don't match. +static Value *foldSelectICmpAnd(SelectInst &Sel, ICmpInst *Cmp, + InstCombiner::BuilderTy &Builder) { + const APInt *SelTC, *SelFC; + if (!match(Sel.getTrueValue(), m_APInt(SelTC)) || + !match(Sel.getFalseValue(), m_APInt(SelFC))) + return nullptr; + + // If this is a vector select, we need a vector compare. + Type *SelType = Sel.getType(); + if (SelType->isVectorTy() != Cmp->getType()->isVectorTy()) + return nullptr; + + Value *V; + APInt AndMask; + bool CreateAnd = false; + ICmpInst::Predicate Pred = Cmp->getPredicate(); + if (ICmpInst::isEquality(Pred)) { + if (!match(Cmp->getOperand(1), m_Zero())) + return nullptr; + + V = Cmp->getOperand(0); + const APInt *AndRHS; + if (!match(V, m_And(m_Value(), m_Power2(AndRHS)))) + return nullptr; + + AndMask = *AndRHS; + } else if (decomposeBitTestICmp(Cmp->getOperand(0), Cmp->getOperand(1), + Pred, V, AndMask)) { + assert(ICmpInst::isEquality(Pred) && "Not equality test?"); + if (!AndMask.isPowerOf2()) + return nullptr; + + CreateAnd = true; + } else { + return nullptr; + } + + // In general, when both constants are non-zero, we would need an offset to + // replace the select. This would require more instructions than we started + // with. But there's one special-case that we handle here because it can + // simplify/reduce the instructions. + APInt TC = *SelTC; + APInt FC = *SelFC; + if (!TC.isZero() && !FC.isZero()) { + // If the select constants differ by exactly one bit and that's the same + // bit that is masked and checked by the select condition, the select can + // be replaced by bitwise logic to set/clear one bit of the constant result. 
+ if (TC.getBitWidth() != AndMask.getBitWidth() || (TC ^ FC) != AndMask) + return nullptr; + if (CreateAnd) { + // If we have to create an 'and', then we must kill the cmp to not + // increase the instruction count. + if (!Cmp->hasOneUse()) + return nullptr; + V = Builder.CreateAnd(V, ConstantInt::get(SelType, AndMask)); + } + bool ExtraBitInTC = TC.ugt(FC); + if (Pred == ICmpInst::ICMP_EQ) { + // If the masked bit in V is clear, clear or set the bit in the result: + // (V & AndMaskC) == 0 ? TC : FC --> (V & AndMaskC) ^ TC + // (V & AndMaskC) == 0 ? TC : FC --> (V & AndMaskC) | TC + Constant *C = ConstantInt::get(SelType, TC); + return ExtraBitInTC ? Builder.CreateXor(V, C) : Builder.CreateOr(V, C); + } + if (Pred == ICmpInst::ICMP_NE) { + // If the masked bit in V is set, set or clear the bit in the result: + // (V & AndMaskC) != 0 ? TC : FC --> (V & AndMaskC) | FC + // (V & AndMaskC) != 0 ? TC : FC --> (V & AndMaskC) ^ FC + Constant *C = ConstantInt::get(SelType, FC); + return ExtraBitInTC ? Builder.CreateOr(V, C) : Builder.CreateXor(V, C); + } + llvm_unreachable("Only expecting equality predicates"); + } + + // Make sure one of the select arms is a power-of-2. + if (!TC.isPowerOf2() && !FC.isPowerOf2()) + return nullptr; + + // Determine which shift is needed to transform result of the 'and' into the + // desired result. + const APInt &ValC = !TC.isZero() ? TC : FC; + unsigned ValZeros = ValC.logBase2(); + unsigned AndZeros = AndMask.logBase2(); + + // Insert the 'and' instruction on the input to the truncate. + if (CreateAnd) + V = Builder.CreateAnd(V, ConstantInt::get(V->getType(), AndMask)); + + // If types don't match, we can still convert the select by introducing a zext + // or a trunc of the 'and'. 
+ if (ValZeros > AndZeros) { + V = Builder.CreateZExtOrTrunc(V, SelType); + V = Builder.CreateShl(V, ValZeros - AndZeros); + } else if (ValZeros < AndZeros) { + V = Builder.CreateLShr(V, AndZeros - ValZeros); + V = Builder.CreateZExtOrTrunc(V, SelType); + } else { + V = Builder.CreateZExtOrTrunc(V, SelType); + } + + // Okay, now we know that everything is set up, we just don't know whether we + // have a icmp_ne or icmp_eq and whether the true or false val is the zero. + bool ShouldNotVal = !TC.isZero(); + ShouldNotVal ^= Pred == ICmpInst::ICMP_NE; + if (ShouldNotVal) + V = Builder.CreateXor(V, ValC); + + return V; +} + +/// We want to turn code that looks like this: +/// %C = or %A, %B +/// %D = select %cond, %C, %A +/// into: +/// %C = select %cond, %B, 0 +/// %D = or %A, %C +/// +/// Assuming that the specified instruction is an operand to the select, return +/// a bitmask indicating which operands of this instruction are foldable if they +/// equal the other incoming value of the select. +static unsigned getSelectFoldableOperands(BinaryOperator *I) { + switch (I->getOpcode()) { + case Instruction::Add: + case Instruction::FAdd: + case Instruction::Mul: + case Instruction::FMul: + case Instruction::And: + case Instruction::Or: + case Instruction::Xor: + return 3; // Can fold through either operand. + case Instruction::Sub: // Can only fold on the amount subtracted. + case Instruction::FSub: + case Instruction::FDiv: // Can only fold on the divisor amount. + case Instruction::Shl: // Can only fold on the shift amount. + case Instruction::LShr: + case Instruction::AShr: + return 1; + default: + return 0; // Cannot fold + } +} + +/// We have (select c, TI, FI), and we know that TI and FI have the same opcode. +Instruction *InstCombinerImpl::foldSelectOpOp(SelectInst &SI, Instruction *TI, + Instruction *FI) { + // Don't break up min/max patterns. The hasOneUse checks below prevent that + // for most cases, but vector min/max with bitcasts can be transformed. 
If the + // one-use restrictions are eased for other patterns, we still don't want to + // obfuscate min/max. + if ((match(&SI, m_SMin(m_Value(), m_Value())) || + match(&SI, m_SMax(m_Value(), m_Value())) || + match(&SI, m_UMin(m_Value(), m_Value())) || + match(&SI, m_UMax(m_Value(), m_Value())))) + return nullptr; + + // If this is a cast from the same type, merge. + Value *Cond = SI.getCondition(); + Type *CondTy = Cond->getType(); + if (TI->getNumOperands() == 1 && TI->isCast()) { + Type *FIOpndTy = FI->getOperand(0)->getType(); + if (TI->getOperand(0)->getType() != FIOpndTy) + return nullptr; + + // The select condition may be a vector. We may only change the operand + // type if the vector width remains the same (and matches the condition). + if (auto *CondVTy = dyn_cast<VectorType>(CondTy)) { + if (!FIOpndTy->isVectorTy() || + CondVTy->getElementCount() != + cast<VectorType>(FIOpndTy)->getElementCount()) + return nullptr; + + // TODO: If the backend knew how to deal with casts better, we could + // remove this limitation. For now, there's too much potential to create + // worse codegen by promoting the select ahead of size-altering casts + // (PR28160). + // + // Note that ValueTracking's matchSelectPattern() looks through casts + // without checking 'hasOneUse' when it matches min/max patterns, so this + // transform may end up happening anyway. + if (TI->getOpcode() != Instruction::BitCast && + (!TI->hasOneUse() || !FI->hasOneUse())) + return nullptr; + } else if (!TI->hasOneUse() || !FI->hasOneUse()) { + // TODO: The one-use restrictions for a scalar select could be eased if + // the fold of a select in visitLoadInst() was enhanced to match a pattern + // that includes a cast. + return nullptr; + } + + // Fold this by inserting a select from the input values. 
+ Value *NewSI = + Builder.CreateSelect(Cond, TI->getOperand(0), FI->getOperand(0), + SI.getName() + ".v", &SI); + return CastInst::Create(Instruction::CastOps(TI->getOpcode()), NewSI, + TI->getType()); + } + + // Cond ? -X : -Y --> -(Cond ? X : Y) + Value *X, *Y; + if (match(TI, m_FNeg(m_Value(X))) && match(FI, m_FNeg(m_Value(Y))) && + (TI->hasOneUse() || FI->hasOneUse())) { + // Intersect FMF from the fneg instructions and union those with the select. + FastMathFlags FMF = TI->getFastMathFlags(); + FMF &= FI->getFastMathFlags(); + FMF |= SI.getFastMathFlags(); + Value *NewSel = Builder.CreateSelect(Cond, X, Y, SI.getName() + ".v", &SI); + if (auto *NewSelI = dyn_cast<Instruction>(NewSel)) + NewSelI->setFastMathFlags(FMF); + Instruction *NewFNeg = UnaryOperator::CreateFNeg(NewSel); + NewFNeg->setFastMathFlags(FMF); + return NewFNeg; + } + + // Min/max intrinsic with a common operand can have the common operand pulled + // after the select. This is the same transform as below for binops, but + // specialized for intrinsic matching and without the restrictive uses clause. 
+ auto *TII = dyn_cast<IntrinsicInst>(TI); + auto *FII = dyn_cast<IntrinsicInst>(FI); + if (TII && FII && TII->getIntrinsicID() == FII->getIntrinsicID() && + (TII->hasOneUse() || FII->hasOneUse())) { + Value *T0, *T1, *F0, *F1; + if (match(TII, m_MaxOrMin(m_Value(T0), m_Value(T1))) && + match(FII, m_MaxOrMin(m_Value(F0), m_Value(F1)))) { + if (T0 == F0) { + Value *NewSel = Builder.CreateSelect(Cond, T1, F1, "minmaxop", &SI); + return CallInst::Create(TII->getCalledFunction(), {NewSel, T0}); + } + if (T0 == F1) { + Value *NewSel = Builder.CreateSelect(Cond, T1, F0, "minmaxop", &SI); + return CallInst::Create(TII->getCalledFunction(), {NewSel, T0}); + } + if (T1 == F0) { + Value *NewSel = Builder.CreateSelect(Cond, T0, F1, "minmaxop", &SI); + return CallInst::Create(TII->getCalledFunction(), {NewSel, T1}); + } + if (T1 == F1) { + Value *NewSel = Builder.CreateSelect(Cond, T0, F0, "minmaxop", &SI); + return CallInst::Create(TII->getCalledFunction(), {NewSel, T1}); + } + } + } + + // Only handle binary operators (including two-operand getelementptr) with + // one-use here. As with the cast case above, it may be possible to relax the + // one-use constraint, but that needs be examined carefully since it may not + // reduce the total number of instructions. + if (TI->getNumOperands() != 2 || FI->getNumOperands() != 2 || + !TI->isSameOperationAs(FI) || + (!isa<BinaryOperator>(TI) && !isa<GetElementPtrInst>(TI)) || + !TI->hasOneUse() || !FI->hasOneUse()) + return nullptr; + + // Figure out if the operations have any operands in common. 
+ Value *MatchOp, *OtherOpT, *OtherOpF; + bool MatchIsOpZero; + if (TI->getOperand(0) == FI->getOperand(0)) { + MatchOp = TI->getOperand(0); + OtherOpT = TI->getOperand(1); + OtherOpF = FI->getOperand(1); + MatchIsOpZero = true; + } else if (TI->getOperand(1) == FI->getOperand(1)) { + MatchOp = TI->getOperand(1); + OtherOpT = TI->getOperand(0); + OtherOpF = FI->getOperand(0); + MatchIsOpZero = false; + } else if (!TI->isCommutative()) { + return nullptr; + } else if (TI->getOperand(0) == FI->getOperand(1)) { + MatchOp = TI->getOperand(0); + OtherOpT = TI->getOperand(1); + OtherOpF = FI->getOperand(0); + MatchIsOpZero = true; + } else if (TI->getOperand(1) == FI->getOperand(0)) { + MatchOp = TI->getOperand(1); + OtherOpT = TI->getOperand(0); + OtherOpF = FI->getOperand(1); + MatchIsOpZero = true; + } else { + return nullptr; + } + + // If the select condition is a vector, the operands of the original select's + // operands also must be vectors. This may not be the case for getelementptr + // for example. + if (CondTy->isVectorTy() && (!OtherOpT->getType()->isVectorTy() || + !OtherOpF->getType()->isVectorTy())) + return nullptr; + + // If we reach here, they do have operations in common. + Value *NewSI = Builder.CreateSelect(Cond, OtherOpT, OtherOpF, + SI.getName() + ".v", &SI); + Value *Op0 = MatchIsOpZero ? MatchOp : NewSI; + Value *Op1 = MatchIsOpZero ? NewSI : MatchOp; + if (auto *BO = dyn_cast<BinaryOperator>(TI)) { + BinaryOperator *NewBO = BinaryOperator::Create(BO->getOpcode(), Op0, Op1); + NewBO->copyIRFlags(TI); + NewBO->andIRFlags(FI); + return NewBO; + } + if (auto *TGEP = dyn_cast<GetElementPtrInst>(TI)) { + auto *FGEP = cast<GetElementPtrInst>(FI); + Type *ElementType = TGEP->getResultElementType(); + return TGEP->isInBounds() && FGEP->isInBounds() + ? 
GetElementPtrInst::CreateInBounds(ElementType, Op0, {Op1}) + : GetElementPtrInst::Create(ElementType, Op0, {Op1}); + } + llvm_unreachable("Expected BinaryOperator or GEP"); + return nullptr; +} + +static bool isSelect01(const APInt &C1I, const APInt &C2I) { + if (!C1I.isZero() && !C2I.isZero()) // One side must be zero. + return false; + return C1I.isOne() || C1I.isAllOnes() || C2I.isOne() || C2I.isAllOnes(); +} + +/// Try to fold the select into one of the operands to allow further +/// optimization. +Instruction *InstCombinerImpl::foldSelectIntoOp(SelectInst &SI, Value *TrueVal, + Value *FalseVal) { + // See the comment above GetSelectFoldableOperands for a description of the + // transformation we are doing here. + auto TryFoldSelectIntoOp = [&](SelectInst &SI, Value *TrueVal, + Value *FalseVal, + bool Swapped) -> Instruction * { + if (auto *TVI = dyn_cast<BinaryOperator>(TrueVal)) { + if (TVI->hasOneUse() && !isa<Constant>(FalseVal)) { + if (unsigned SFO = getSelectFoldableOperands(TVI)) { + unsigned OpToFold = 0; + if ((SFO & 1) && FalseVal == TVI->getOperand(0)) + OpToFold = 1; + else if ((SFO & 2) && FalseVal == TVI->getOperand(1)) + OpToFold = 2; + + if (OpToFold) { + FastMathFlags FMF; + // TODO: We probably ought to revisit cases where the select and FP + // instructions have different flags and add tests to ensure the + // behaviour is correct. + if (isa<FPMathOperator>(&SI)) + FMF = SI.getFastMathFlags(); + Constant *C = ConstantExpr::getBinOpIdentity( + TVI->getOpcode(), TVI->getType(), true, FMF.noSignedZeros()); + Value *OOp = TVI->getOperand(2 - OpToFold); + // Avoid creating select between 2 constants unless it's selecting + // between 0, 1 and -1. + const APInt *OOpC; + bool OOpIsAPInt = match(OOp, m_APInt(OOpC)); + if (!isa<Constant>(OOp) || + (OOpIsAPInt && isSelect01(C->getUniqueInteger(), *OOpC))) { + Value *NewSel = Builder.CreateSelect( + SI.getCondition(), Swapped ? C : OOp, Swapped ? 
OOp : C); + if (isa<FPMathOperator>(&SI)) + cast<Instruction>(NewSel)->setFastMathFlags(FMF); + NewSel->takeName(TVI); + BinaryOperator *BO = + BinaryOperator::Create(TVI->getOpcode(), FalseVal, NewSel); + BO->copyIRFlags(TVI); + return BO; + } + } + } + } + } + return nullptr; + }; + + if (Instruction *R = TryFoldSelectIntoOp(SI, TrueVal, FalseVal, false)) + return R; + + if (Instruction *R = TryFoldSelectIntoOp(SI, FalseVal, TrueVal, true)) + return R; + + return nullptr; +} + +/// We want to turn: +/// (select (icmp eq (and X, Y), 0), (and (lshr X, Z), 1), 1) +/// into: +/// zext (icmp ne i32 (and X, (or Y, (shl 1, Z))), 0) +/// Note: +/// Z may be 0 if lshr is missing. +/// Worst-case scenario is that we will replace 5 instructions with 5 different +/// instructions, but we got rid of select. +static Instruction *foldSelectICmpAndAnd(Type *SelType, const ICmpInst *Cmp, + Value *TVal, Value *FVal, + InstCombiner::BuilderTy &Builder) { + if (!(Cmp->hasOneUse() && Cmp->getOperand(0)->hasOneUse() && + Cmp->getPredicate() == ICmpInst::ICMP_EQ && + match(Cmp->getOperand(1), m_Zero()) && match(FVal, m_One()))) + return nullptr; + + // The TrueVal has general form of: and %B, 1 + Value *B; + if (!match(TVal, m_OneUse(m_And(m_Value(B), m_One())))) + return nullptr; + + // Where %B may be optionally shifted: lshr %X, %Z. + Value *X, *Z; + const bool HasShift = match(B, m_OneUse(m_LShr(m_Value(X), m_Value(Z)))); + + // The shift must be valid. + // TODO: This restricts the fold to constant shift amounts. Is there a way to + // handle variable shifts safely? PR47012 + if (HasShift && + !match(Z, m_SpecificInt_ICMP(CmpInst::ICMP_ULT, + APInt(SelType->getScalarSizeInBits(), + SelType->getScalarSizeInBits())))) + return nullptr; + + if (!HasShift) + X = B; + + Value *Y; + if (!match(Cmp->getOperand(0), m_c_And(m_Specific(X), m_Value(Y)))) + return nullptr; + + // ((X & Y) == 0) ? ((X >> Z) & 1) : 1 --> (X & (Y | (1 << Z))) != 0 + // ((X & Y) == 0) ? 
(X & 1) : 1 --> (X & (Y | 1)) != 0 + Constant *One = ConstantInt::get(SelType, 1); + Value *MaskB = HasShift ? Builder.CreateShl(One, Z) : One; + Value *FullMask = Builder.CreateOr(Y, MaskB); + Value *MaskedX = Builder.CreateAnd(X, FullMask); + Value *ICmpNeZero = Builder.CreateIsNotNull(MaskedX); + return new ZExtInst(ICmpNeZero, SelType); +} + +/// We want to turn: +/// (select (icmp sgt x, C), lshr (X, Y), ashr (X, Y)); iff C s>= -1 +/// (select (icmp slt x, C), ashr (X, Y), lshr (X, Y)); iff C s>= 0 +/// into: +/// ashr (X, Y) +static Value *foldSelectICmpLshrAshr(const ICmpInst *IC, Value *TrueVal, + Value *FalseVal, + InstCombiner::BuilderTy &Builder) { + ICmpInst::Predicate Pred = IC->getPredicate(); + Value *CmpLHS = IC->getOperand(0); + Value *CmpRHS = IC->getOperand(1); + if (!CmpRHS->getType()->isIntOrIntVectorTy()) + return nullptr; + + Value *X, *Y; + unsigned Bitwidth = CmpRHS->getType()->getScalarSizeInBits(); + if ((Pred != ICmpInst::ICMP_SGT || + !match(CmpRHS, + m_SpecificInt_ICMP(ICmpInst::ICMP_SGE, APInt(Bitwidth, -1)))) && + (Pred != ICmpInst::ICMP_SLT || + !match(CmpRHS, + m_SpecificInt_ICMP(ICmpInst::ICMP_SGE, APInt(Bitwidth, 0))))) + return nullptr; + + // Canonicalize so that ashr is in FalseVal. + if (Pred == ICmpInst::ICMP_SLT) + std::swap(TrueVal, FalseVal); + + if (match(TrueVal, m_LShr(m_Value(X), m_Value(Y))) && + match(FalseVal, m_AShr(m_Specific(X), m_Specific(Y))) && + match(CmpLHS, m_Specific(X))) { + const auto *Ashr = cast<Instruction>(FalseVal); + // if lshr is not exact and ashr is, this new ashr must not be exact. 
+ bool IsExact = Ashr->isExact() && cast<Instruction>(TrueVal)->isExact(); + return Builder.CreateAShr(X, Y, IC->getName(), IsExact); + } + + return nullptr; +} + +/// We want to turn: +/// (select (icmp eq (and X, C1), 0), Y, (or Y, C2)) +/// into: +/// (or (shl (and X, C1), C3), Y) +/// iff: +/// C1 and C2 are both powers of 2 +/// where: +/// C3 = Log(C2) - Log(C1) +/// +/// This transform handles cases where: +/// 1. The icmp predicate is inverted +/// 2. The select operands are reversed +/// 3. The magnitude of C2 and C1 are flipped +static Value *foldSelectICmpAndOr(const ICmpInst *IC, Value *TrueVal, + Value *FalseVal, + InstCombiner::BuilderTy &Builder) { + // Only handle integer compares. Also, if this is a vector select, we need a + // vector compare. + if (!TrueVal->getType()->isIntOrIntVectorTy() || + TrueVal->getType()->isVectorTy() != IC->getType()->isVectorTy()) + return nullptr; + + Value *CmpLHS = IC->getOperand(0); + Value *CmpRHS = IC->getOperand(1); + + Value *V; + unsigned C1Log; + bool IsEqualZero; + bool NeedAnd = false; + if (IC->isEquality()) { + if (!match(CmpRHS, m_Zero())) + return nullptr; + + const APInt *C1; + if (!match(CmpLHS, m_And(m_Value(), m_Power2(C1)))) + return nullptr; + + V = CmpLHS; + C1Log = C1->logBase2(); + IsEqualZero = IC->getPredicate() == ICmpInst::ICMP_EQ; + } else if (IC->getPredicate() == ICmpInst::ICMP_SLT || + IC->getPredicate() == ICmpInst::ICMP_SGT) { + // We also need to recognize (icmp slt (trunc (X)), 0) and + // (icmp sgt (trunc (X)), -1). 
+ IsEqualZero = IC->getPredicate() == ICmpInst::ICMP_SGT; + if ((IsEqualZero && !match(CmpRHS, m_AllOnes())) || + (!IsEqualZero && !match(CmpRHS, m_Zero()))) + return nullptr; + + if (!match(CmpLHS, m_OneUse(m_Trunc(m_Value(V))))) + return nullptr; + + C1Log = CmpLHS->getType()->getScalarSizeInBits() - 1; + NeedAnd = true; + } else { + return nullptr; + } + + const APInt *C2; + bool OrOnTrueVal = false; + bool OrOnFalseVal = match(FalseVal, m_Or(m_Specific(TrueVal), m_Power2(C2))); + if (!OrOnFalseVal) + OrOnTrueVal = match(TrueVal, m_Or(m_Specific(FalseVal), m_Power2(C2))); + + if (!OrOnFalseVal && !OrOnTrueVal) + return nullptr; + + Value *Y = OrOnFalseVal ? TrueVal : FalseVal; + + unsigned C2Log = C2->logBase2(); + + bool NeedXor = (!IsEqualZero && OrOnFalseVal) || (IsEqualZero && OrOnTrueVal); + bool NeedShift = C1Log != C2Log; + bool NeedZExtTrunc = Y->getType()->getScalarSizeInBits() != + V->getType()->getScalarSizeInBits(); + + // Make sure we don't create more instructions than we save. + Value *Or = OrOnFalseVal ? FalseVal : TrueVal; + if ((NeedShift + NeedXor + NeedZExtTrunc) > + (IC->hasOneUse() + Or->hasOneUse())) + return nullptr; + + if (NeedAnd) { + // Insert the AND instruction on the input to the truncate. + APInt C1 = APInt::getOneBitSet(V->getType()->getScalarSizeInBits(), C1Log); + V = Builder.CreateAnd(V, ConstantInt::get(V->getType(), C1)); + } + + if (C2Log > C1Log) { + V = Builder.CreateZExtOrTrunc(V, Y->getType()); + V = Builder.CreateShl(V, C2Log - C1Log); + } else if (C1Log > C2Log) { + V = Builder.CreateLShr(V, C1Log - C2Log); + V = Builder.CreateZExtOrTrunc(V, Y->getType()); + } else + V = Builder.CreateZExtOrTrunc(V, Y->getType()); + + if (NeedXor) + V = Builder.CreateXor(V, *C2); + + return Builder.CreateOr(V, Y); +} + +/// Canonicalize a set or clear of a masked set of constant bits to +/// select-of-constants form. 
+static Instruction *foldSetClearBits(SelectInst &Sel, + InstCombiner::BuilderTy &Builder) { + Value *Cond = Sel.getCondition(); + Value *T = Sel.getTrueValue(); + Value *F = Sel.getFalseValue(); + Type *Ty = Sel.getType(); + Value *X; + const APInt *NotC, *C; + + // Cond ? (X & ~C) : (X | C) --> (X & ~C) | (Cond ? 0 : C) + if (match(T, m_And(m_Value(X), m_APInt(NotC))) && + match(F, m_OneUse(m_Or(m_Specific(X), m_APInt(C)))) && *NotC == ~(*C)) { + Constant *Zero = ConstantInt::getNullValue(Ty); + Constant *OrC = ConstantInt::get(Ty, *C); + Value *NewSel = Builder.CreateSelect(Cond, Zero, OrC, "masksel", &Sel); + return BinaryOperator::CreateOr(T, NewSel); + } + + // Cond ? (X | C) : (X & ~C) --> (X & ~C) | (Cond ? C : 0) + if (match(F, m_And(m_Value(X), m_APInt(NotC))) && + match(T, m_OneUse(m_Or(m_Specific(X), m_APInt(C)))) && *NotC == ~(*C)) { + Constant *Zero = ConstantInt::getNullValue(Ty); + Constant *OrC = ConstantInt::get(Ty, *C); + Value *NewSel = Builder.CreateSelect(Cond, OrC, Zero, "masksel", &Sel); + return BinaryOperator::CreateOr(F, NewSel); + } + + return nullptr; +} + +// select (x == 0), 0, x * y --> freeze(y) * x +// select (y == 0), 0, x * y --> freeze(x) * y +// select (x == 0), undef, x * y --> freeze(y) * x +// select (x == undef), 0, x * y --> freeze(y) * x +// Usage of mul instead of 0 will make the result more poisonous, +// so the operand that was not checked in the condition should be frozen. +// The latter folding is applied only when a constant compared with x is +// is a vector consisting of 0 and undefs. If a constant compared with x +// is a scalar undefined value or undefined vector then an expression +// should be already folded into a constant. 
+static Instruction *foldSelectZeroOrMul(SelectInst &SI, InstCombinerImpl &IC) { + auto *CondVal = SI.getCondition(); + auto *TrueVal = SI.getTrueValue(); + auto *FalseVal = SI.getFalseValue(); + Value *X, *Y; + ICmpInst::Predicate Predicate; + + // Assuming that constant compared with zero is not undef (but it may be + // a vector with some undef elements). Otherwise (when a constant is undef) + // the select expression should be already simplified. + if (!match(CondVal, m_ICmp(Predicate, m_Value(X), m_Zero())) || + !ICmpInst::isEquality(Predicate)) + return nullptr; + + if (Predicate == ICmpInst::ICMP_NE) + std::swap(TrueVal, FalseVal); + + // Check that TrueVal is a constant instead of matching it with m_Zero() + // to handle the case when it is a scalar undef value or a vector containing + // non-zero elements that are masked by undef elements in the compare + // constant. + auto *TrueValC = dyn_cast<Constant>(TrueVal); + if (TrueValC == nullptr || + !match(FalseVal, m_c_Mul(m_Specific(X), m_Value(Y))) || + !isa<Instruction>(FalseVal)) + return nullptr; + + auto *ZeroC = cast<Constant>(cast<Instruction>(CondVal)->getOperand(1)); + auto *MergedC = Constant::mergeUndefsWith(TrueValC, ZeroC); + // If X is compared with 0 then TrueVal could be either zero or undef. + // m_Zero match vectors containing some undef elements, but for scalars + // m_Undef should be used explicitly. + if (!match(MergedC, m_Zero()) && !match(MergedC, m_Undef())) + return nullptr; + + auto *FalseValI = cast<Instruction>(FalseVal); + auto *FrY = IC.InsertNewInstBefore(new FreezeInst(Y, Y->getName() + ".fr"), + *FalseValI); + IC.replaceOperand(*FalseValI, FalseValI->getOperand(0) == Y ? 0 : 1, FrY); + return IC.replaceInstUsesWith(SI, FalseValI); +} + +/// Transform patterns such as (a > b) ? a - b : 0 into usub.sat(a, b). +/// There are 8 commuted/swapped variants of this pattern. +/// TODO: Also support a - UMIN(a,b) patterns. 
+static Value *canonicalizeSaturatedSubtract(const ICmpInst *ICI, + const Value *TrueVal, + const Value *FalseVal, + InstCombiner::BuilderTy &Builder) { + ICmpInst::Predicate Pred = ICI->getPredicate(); + if (!ICmpInst::isUnsigned(Pred)) + return nullptr; + + // (b > a) ? 0 : a - b -> (b <= a) ? a - b : 0 + if (match(TrueVal, m_Zero())) { + Pred = ICmpInst::getInversePredicate(Pred); + std::swap(TrueVal, FalseVal); + } + if (!match(FalseVal, m_Zero())) + return nullptr; + + Value *A = ICI->getOperand(0); + Value *B = ICI->getOperand(1); + if (Pred == ICmpInst::ICMP_ULE || Pred == ICmpInst::ICMP_ULT) { + // (b < a) ? a - b : 0 -> (a > b) ? a - b : 0 + std::swap(A, B); + Pred = ICmpInst::getSwappedPredicate(Pred); + } + + assert((Pred == ICmpInst::ICMP_UGE || Pred == ICmpInst::ICMP_UGT) && + "Unexpected isUnsigned predicate!"); + + // Ensure the sub is of the form: + // (a > b) ? a - b : 0 -> usub.sat(a, b) + // (a > b) ? b - a : 0 -> -usub.sat(a, b) + // Checking for both a-b and a+(-b) as a constant. + bool IsNegative = false; + const APInt *C; + if (match(TrueVal, m_Sub(m_Specific(B), m_Specific(A))) || + (match(A, m_APInt(C)) && + match(TrueVal, m_Add(m_Specific(B), m_SpecificInt(-*C))))) + IsNegative = true; + else if (!match(TrueVal, m_Sub(m_Specific(A), m_Specific(B))) && + !(match(B, m_APInt(C)) && + match(TrueVal, m_Add(m_Specific(A), m_SpecificInt(-*C))))) + return nullptr; + + // If we are adding a negate and the sub and icmp are used anywhere else, we + // would end up with more instructions. + if (IsNegative && !TrueVal->hasOneUse() && !ICI->hasOneUse()) + return nullptr; + + // (a > b) ? a - b : 0 -> usub.sat(a, b) + // (a > b) ? 
b - a : 0 -> -usub.sat(a, b) + Value *Result = Builder.CreateBinaryIntrinsic(Intrinsic::usub_sat, A, B); + if (IsNegative) + Result = Builder.CreateNeg(Result); + return Result; +} + +static Value *canonicalizeSaturatedAdd(ICmpInst *Cmp, Value *TVal, Value *FVal, + InstCombiner::BuilderTy &Builder) { + if (!Cmp->hasOneUse()) + return nullptr; + + // Match unsigned saturated add with constant. + Value *Cmp0 = Cmp->getOperand(0); + Value *Cmp1 = Cmp->getOperand(1); + ICmpInst::Predicate Pred = Cmp->getPredicate(); + Value *X; + const APInt *C, *CmpC; + if (Pred == ICmpInst::ICMP_ULT && + match(TVal, m_Add(m_Value(X), m_APInt(C))) && X == Cmp0 && + match(FVal, m_AllOnes()) && match(Cmp1, m_APInt(CmpC)) && *CmpC == ~*C) { + // (X u< ~C) ? (X + C) : -1 --> uadd.sat(X, C) + return Builder.CreateBinaryIntrinsic( + Intrinsic::uadd_sat, X, ConstantInt::get(X->getType(), *C)); + } + + // Match unsigned saturated add of 2 variables with an unnecessary 'not'. + // There are 8 commuted variants. + // Canonicalize -1 (saturated result) to true value of the select. + if (match(FVal, m_AllOnes())) { + std::swap(TVal, FVal); + Pred = CmpInst::getInversePredicate(Pred); + } + if (!match(TVal, m_AllOnes())) + return nullptr; + + // Canonicalize predicate to less-than or less-or-equal-than. + if (Pred == ICmpInst::ICMP_UGT || Pred == ICmpInst::ICMP_UGE) { + std::swap(Cmp0, Cmp1); + Pred = CmpInst::getSwappedPredicate(Pred); + } + if (Pred != ICmpInst::ICMP_ULT && Pred != ICmpInst::ICMP_ULE) + return nullptr; + + // Match unsigned saturated add of 2 variables with an unnecessary 'not'. + // Strictness of the comparison is irrelevant. + Value *Y; + if (match(Cmp0, m_Not(m_Value(X))) && + match(FVal, m_c_Add(m_Specific(X), m_Value(Y))) && Y == Cmp1) { + // (~X u< Y) ? -1 : (X + Y) --> uadd.sat(X, Y) + // (~X u< Y) ? -1 : (Y + X) --> uadd.sat(X, Y) + return Builder.CreateBinaryIntrinsic(Intrinsic::uadd_sat, X, Y); + } + // The 'not' op may be included in the sum but not the compare. 
+ // Strictness of the comparison is irrelevant. + X = Cmp0; + Y = Cmp1; + if (match(FVal, m_c_Add(m_Not(m_Specific(X)), m_Specific(Y)))) { + // (X u< Y) ? -1 : (~X + Y) --> uadd.sat(~X, Y) + // (X u< Y) ? -1 : (Y + ~X) --> uadd.sat(Y, ~X) + BinaryOperator *BO = cast<BinaryOperator>(FVal); + return Builder.CreateBinaryIntrinsic( + Intrinsic::uadd_sat, BO->getOperand(0), BO->getOperand(1)); + } + // The overflow may be detected via the add wrapping round. + // This is only valid for strict comparison! + if (Pred == ICmpInst::ICMP_ULT && + match(Cmp0, m_c_Add(m_Specific(Cmp1), m_Value(Y))) && + match(FVal, m_c_Add(m_Specific(Cmp1), m_Specific(Y)))) { + // ((X + Y) u< X) ? -1 : (X + Y) --> uadd.sat(X, Y) + // ((X + Y) u< Y) ? -1 : (X + Y) --> uadd.sat(X, Y) + return Builder.CreateBinaryIntrinsic(Intrinsic::uadd_sat, Cmp1, Y); + } + + return nullptr; +} + +/// Fold the following code sequence: +/// \code +/// int a = ctlz(x & -x); +// x ? 31 - a : a; +/// \code +/// +/// into: +/// cttz(x) +static Instruction *foldSelectCtlzToCttz(ICmpInst *ICI, Value *TrueVal, + Value *FalseVal, + InstCombiner::BuilderTy &Builder) { + unsigned BitWidth = TrueVal->getType()->getScalarSizeInBits(); + if (!ICI->isEquality() || !match(ICI->getOperand(1), m_Zero())) + return nullptr; + + if (ICI->getPredicate() == ICmpInst::ICMP_NE) + std::swap(TrueVal, FalseVal); + + if (!match(FalseVal, + m_Xor(m_Deferred(TrueVal), m_SpecificInt(BitWidth - 1)))) + return nullptr; + + if (!match(TrueVal, m_Intrinsic<Intrinsic::ctlz>())) + return nullptr; + + Value *X = ICI->getOperand(0); + auto *II = cast<IntrinsicInst>(TrueVal); + if (!match(II->getOperand(0), m_c_And(m_Specific(X), m_Neg(m_Specific(X))))) + return nullptr; + + Function *F = Intrinsic::getDeclaration(II->getModule(), Intrinsic::cttz, + II->getType()); + return CallInst::Create(F, {X, II->getArgOperand(1)}); +} + +/// Attempt to fold a cttz/ctlz followed by a icmp plus select into a single +/// call to cttz/ctlz with flag 
'is_zero_poison' cleared. +/// +/// For example, we can fold the following code sequence: +/// \code +/// %0 = tail call i32 @llvm.cttz.i32(i32 %x, i1 true) +/// %1 = icmp ne i32 %x, 0 +/// %2 = select i1 %1, i32 %0, i32 32 +/// \code +/// +/// into: +/// %0 = tail call i32 @llvm.cttz.i32(i32 %x, i1 false) +static Value *foldSelectCttzCtlz(ICmpInst *ICI, Value *TrueVal, Value *FalseVal, + InstCombiner::BuilderTy &Builder) { + ICmpInst::Predicate Pred = ICI->getPredicate(); + Value *CmpLHS = ICI->getOperand(0); + Value *CmpRHS = ICI->getOperand(1); + + // Check if the condition value compares a value for equality against zero. + if (!ICI->isEquality() || !match(CmpRHS, m_Zero())) + return nullptr; + + Value *SelectArg = FalseVal; + Value *ValueOnZero = TrueVal; + if (Pred == ICmpInst::ICMP_NE) + std::swap(SelectArg, ValueOnZero); + + // Skip zero extend/truncate. + Value *Count = nullptr; + if (!match(SelectArg, m_ZExt(m_Value(Count))) && + !match(SelectArg, m_Trunc(m_Value(Count)))) + Count = SelectArg; + + // Check that 'Count' is a call to intrinsic cttz/ctlz. Also check that the + // input to the cttz/ctlz is used as LHS for the compare instruction. + if (!match(Count, m_Intrinsic<Intrinsic::cttz>(m_Specific(CmpLHS))) && + !match(Count, m_Intrinsic<Intrinsic::ctlz>(m_Specific(CmpLHS)))) + return nullptr; + + IntrinsicInst *II = cast<IntrinsicInst>(Count); + + // Check if the value propagated on zero is a constant number equal to the + // sizeof in bits of 'Count'. + unsigned SizeOfInBits = Count->getType()->getScalarSizeInBits(); + if (match(ValueOnZero, m_SpecificInt(SizeOfInBits))) { + // Explicitly clear the 'is_zero_poison' flag. It's always valid to go from + // true to false on this flag, so we can replace it for all users. + II->setArgOperand(1, ConstantInt::getFalse(II->getContext())); + return SelectArg; + } + + // The ValueOnZero is not the bitwidth. 
But if the cttz/ctlz (and optional + // zext/trunc) have one use (ending at the select), the cttz/ctlz result will + // not be used if the input is zero. Relax to 'zero is poison' for that case. + if (II->hasOneUse() && SelectArg->hasOneUse() && + !match(II->getArgOperand(1), m_One())) + II->setArgOperand(1, ConstantInt::getTrue(II->getContext())); + + return nullptr; +} + +/// Return true if we find and adjust an icmp+select pattern where the compare +/// is with a constant that can be incremented or decremented to match the +/// minimum or maximum idiom. +static bool adjustMinMax(SelectInst &Sel, ICmpInst &Cmp) { + ICmpInst::Predicate Pred = Cmp.getPredicate(); + Value *CmpLHS = Cmp.getOperand(0); + Value *CmpRHS = Cmp.getOperand(1); + Value *TrueVal = Sel.getTrueValue(); + Value *FalseVal = Sel.getFalseValue(); + + // We may move or edit the compare, so make sure the select is the only user. + const APInt *CmpC; + if (!Cmp.hasOneUse() || !match(CmpRHS, m_APInt(CmpC))) + return false; + + // These transforms only work for selects of integers or vector selects of + // integer vectors. + Type *SelTy = Sel.getType(); + auto *SelEltTy = dyn_cast<IntegerType>(SelTy->getScalarType()); + if (!SelEltTy || SelTy->isVectorTy() != Cmp.getType()->isVectorTy()) + return false; + + Constant *AdjustedRHS; + if (Pred == ICmpInst::ICMP_UGT || Pred == ICmpInst::ICMP_SGT) + AdjustedRHS = ConstantInt::get(CmpRHS->getType(), *CmpC + 1); + else if (Pred == ICmpInst::ICMP_ULT || Pred == ICmpInst::ICMP_SLT) + AdjustedRHS = ConstantInt::get(CmpRHS->getType(), *CmpC - 1); + else + return false; + + // X > C ? X : C+1 --> X < C+1 ? C+1 : X + // X < C ? X : C-1 --> X > C-1 ? C-1 : X + if ((CmpLHS == TrueVal && AdjustedRHS == FalseVal) || + (CmpLHS == FalseVal && AdjustedRHS == TrueVal)) { + ; // Nothing to do here. Values match without any sign/zero extension. + } + // Types do not match. Instead of calculating this with mixed types, promote + // all to the larger type. 
This enables scalar evolution to analyze this + // expression. + else if (CmpRHS->getType()->getScalarSizeInBits() < SelEltTy->getBitWidth()) { + Constant *SextRHS = ConstantExpr::getSExt(AdjustedRHS, SelTy); + + // X = sext x; x >s c ? X : C+1 --> X = sext x; X <s C+1 ? C+1 : X + // X = sext x; x <s c ? X : C-1 --> X = sext x; X >s C-1 ? C-1 : X + // X = sext x; x >u c ? X : C+1 --> X = sext x; X <u C+1 ? C+1 : X + // X = sext x; x <u c ? X : C-1 --> X = sext x; X >u C-1 ? C-1 : X + if (match(TrueVal, m_SExt(m_Specific(CmpLHS))) && SextRHS == FalseVal) { + CmpLHS = TrueVal; + AdjustedRHS = SextRHS; + } else if (match(FalseVal, m_SExt(m_Specific(CmpLHS))) && + SextRHS == TrueVal) { + CmpLHS = FalseVal; + AdjustedRHS = SextRHS; + } else if (Cmp.isUnsigned()) { + Constant *ZextRHS = ConstantExpr::getZExt(AdjustedRHS, SelTy); + // X = zext x; x >u c ? X : C+1 --> X = zext x; X <u C+1 ? C+1 : X + // X = zext x; x <u c ? X : C-1 --> X = zext x; X >u C-1 ? C-1 : X + // zext + signed compare cannot be changed: + // 0xff <s 0x00, but 0x00ff >s 0x0000 + if (match(TrueVal, m_ZExt(m_Specific(CmpLHS))) && ZextRHS == FalseVal) { + CmpLHS = TrueVal; + AdjustedRHS = ZextRHS; + } else if (match(FalseVal, m_ZExt(m_Specific(CmpLHS))) && + ZextRHS == TrueVal) { + CmpLHS = FalseVal; + AdjustedRHS = ZextRHS; + } else { + return false; + } + } else { + return false; + } + } else { + return false; + } + + Pred = ICmpInst::getSwappedPredicate(Pred); + CmpRHS = AdjustedRHS; + std::swap(FalseVal, TrueVal); + Cmp.setPredicate(Pred); + Cmp.setOperand(0, CmpLHS); + Cmp.setOperand(1, CmpRHS); + Sel.setOperand(1, TrueVal); + Sel.setOperand(2, FalseVal); + Sel.swapProfMetadata(); + + // Move the compare instruction right before the select instruction. Otherwise + // the sext/zext value may be defined after the compare instruction uses it. 
+ Cmp.moveBefore(&Sel); + + return true; +} + +static Instruction *canonicalizeSPF(SelectInst &Sel, ICmpInst &Cmp, + InstCombinerImpl &IC) { + Value *LHS, *RHS; + // TODO: What to do with pointer min/max patterns? + if (!Sel.getType()->isIntOrIntVectorTy()) + return nullptr; + + SelectPatternFlavor SPF = matchSelectPattern(&Sel, LHS, RHS).Flavor; + if (SPF == SelectPatternFlavor::SPF_ABS || + SPF == SelectPatternFlavor::SPF_NABS) { + if (!Cmp.hasOneUse() && !RHS->hasOneUse()) + return nullptr; // TODO: Relax this restriction. + + // Note that NSW flag can only be propagated for normal, non-negated abs! + bool IntMinIsPoison = SPF == SelectPatternFlavor::SPF_ABS && + match(RHS, m_NSWNeg(m_Specific(LHS))); + Constant *IntMinIsPoisonC = + ConstantInt::get(Type::getInt1Ty(Sel.getContext()), IntMinIsPoison); + Instruction *Abs = + IC.Builder.CreateBinaryIntrinsic(Intrinsic::abs, LHS, IntMinIsPoisonC); + + if (SPF == SelectPatternFlavor::SPF_NABS) + return BinaryOperator::CreateNeg(Abs); // Always without NSW flag! + return IC.replaceInstUsesWith(Sel, Abs); + } + + if (SelectPatternResult::isMinOrMax(SPF)) { + Intrinsic::ID IntrinsicID; + switch (SPF) { + case SelectPatternFlavor::SPF_UMIN: + IntrinsicID = Intrinsic::umin; + break; + case SelectPatternFlavor::SPF_UMAX: + IntrinsicID = Intrinsic::umax; + break; + case SelectPatternFlavor::SPF_SMIN: + IntrinsicID = Intrinsic::smin; + break; + case SelectPatternFlavor::SPF_SMAX: + IntrinsicID = Intrinsic::smax; + break; + default: + llvm_unreachable("Unexpected SPF"); + } + return IC.replaceInstUsesWith( + Sel, IC.Builder.CreateBinaryIntrinsic(IntrinsicID, LHS, RHS)); + } + + return nullptr; +} + +/// If we have a select with an equality comparison, then we know the value in +/// one of the arms of the select. See if substituting this value into an arm +/// and simplifying the result yields the same value as the other arm. 
+/// +/// To make this transform safe, we must drop poison-generating flags +/// (nsw, etc) if we simplified to a binop because the select may be guarding +/// that poison from propagating. If the existing binop already had no +/// poison-generating flags, then this transform can be done by instsimplify. +/// +/// Consider: +/// %cmp = icmp eq i32 %x, 2147483647 +/// %add = add nsw i32 %x, 1 +/// %sel = select i1 %cmp, i32 -2147483648, i32 %add +/// +/// We can't replace %sel with %add unless we strip away the flags. +/// TODO: Wrapping flags could be preserved in some cases with better analysis. +Instruction *InstCombinerImpl::foldSelectValueEquivalence(SelectInst &Sel, + ICmpInst &Cmp) { + // Value equivalence substitution requires an all-or-nothing replacement. + // It does not make sense for a vector compare where each lane is chosen + // independently. + if (!Cmp.isEquality() || Cmp.getType()->isVectorTy()) + return nullptr; + + // Canonicalize the pattern to ICMP_EQ by swapping the select operands. + Value *TrueVal = Sel.getTrueValue(), *FalseVal = Sel.getFalseValue(); + bool Swapped = false; + if (Cmp.getPredicate() == ICmpInst::ICMP_NE) { + std::swap(TrueVal, FalseVal); + Swapped = true; + } + + // In X == Y ? f(X) : Z, try to evaluate f(Y) and replace the operand. + // Make sure Y cannot be undef though, as we might pick different values for + // undef in the icmp and in f(Y). Additionally, take care to avoid replacing + // X == Y ? X : Z with X == Y ? Y : Z, as that would lead to an infinite + // replacement cycle. + Value *CmpLHS = Cmp.getOperand(0), *CmpRHS = Cmp.getOperand(1); + if (TrueVal != CmpLHS && + isGuaranteedNotToBeUndefOrPoison(CmpRHS, SQ.AC, &Sel, &DT)) { + if (Value *V = simplifyWithOpReplaced(TrueVal, CmpLHS, CmpRHS, SQ, + /* AllowRefinement */ true)) + return replaceOperand(Sel, Swapped ? 
2 : 1, V); + + // Even if TrueVal does not simplify, we can directly replace a use of + // CmpLHS with CmpRHS, as long as the instruction is not used anywhere + // else and is safe to speculatively execute (we may end up executing it + // with different operands, which should not cause side-effects or trigger + // undefined behavior). Only do this if CmpRHS is a constant, as + // profitability is not clear for other cases. + // FIXME: The replacement could be performed recursively. + if (match(CmpRHS, m_ImmConstant()) && !match(CmpLHS, m_ImmConstant())) + if (auto *I = dyn_cast<Instruction>(TrueVal)) + if (I->hasOneUse() && isSafeToSpeculativelyExecute(I)) + for (Use &U : I->operands()) + if (U == CmpLHS) { + replaceUse(U, CmpRHS); + return &Sel; + } + } + if (TrueVal != CmpRHS && + isGuaranteedNotToBeUndefOrPoison(CmpLHS, SQ.AC, &Sel, &DT)) + if (Value *V = simplifyWithOpReplaced(TrueVal, CmpRHS, CmpLHS, SQ, + /* AllowRefinement */ true)) + return replaceOperand(Sel, Swapped ? 2 : 1, V); + + auto *FalseInst = dyn_cast<Instruction>(FalseVal); + if (!FalseInst) + return nullptr; + + // InstSimplify already performed this fold if it was possible subject to + // current poison-generating flags. Try the transform again with + // poison-generating flags temporarily dropped. + bool WasNUW = false, WasNSW = false, WasExact = false, WasInBounds = false; + if (auto *OBO = dyn_cast<OverflowingBinaryOperator>(FalseVal)) { + WasNUW = OBO->hasNoUnsignedWrap(); + WasNSW = OBO->hasNoSignedWrap(); + FalseInst->setHasNoUnsignedWrap(false); + FalseInst->setHasNoSignedWrap(false); + } + if (auto *PEO = dyn_cast<PossiblyExactOperator>(FalseVal)) { + WasExact = PEO->isExact(); + FalseInst->setIsExact(false); + } + if (auto *GEP = dyn_cast<GetElementPtrInst>(FalseVal)) { + WasInBounds = GEP->isInBounds(); + GEP->setIsInBounds(false); + } + + // Try each equivalence substitution possibility. + // We have an 'EQ' comparison, so the select's false value will propagate. 
+ // Example: + // (X == 42) ? 43 : (X + 1) --> (X == 42) ? (X + 1) : (X + 1) --> X + 1 + if (simplifyWithOpReplaced(FalseVal, CmpLHS, CmpRHS, SQ, + /* AllowRefinement */ false) == TrueVal || + simplifyWithOpReplaced(FalseVal, CmpRHS, CmpLHS, SQ, + /* AllowRefinement */ false) == TrueVal) { + return replaceInstUsesWith(Sel, FalseVal); + } + + // Restore poison-generating flags if the transform did not apply. + if (WasNUW) + FalseInst->setHasNoUnsignedWrap(); + if (WasNSW) + FalseInst->setHasNoSignedWrap(); + if (WasExact) + FalseInst->setIsExact(); + if (WasInBounds) + cast<GetElementPtrInst>(FalseInst)->setIsInBounds(); + + return nullptr; +} + +// See if this is a pattern like: +// %old_cmp1 = icmp slt i32 %x, C2 +// %old_replacement = select i1 %old_cmp1, i32 %target_low, i32 %target_high +// %old_x_offseted = add i32 %x, C1 +// %old_cmp0 = icmp ult i32 %old_x_offseted, C0 +// %r = select i1 %old_cmp0, i32 %x, i32 %old_replacement +// This can be rewritten as more canonical pattern: +// %new_cmp1 = icmp slt i32 %x, -C1 +// %new_cmp2 = icmp sge i32 %x, C0-C1 +// %new_clamped_low = select i1 %new_cmp1, i32 %target_low, i32 %x +// %r = select i1 %new_cmp2, i32 %target_high, i32 %new_clamped_low +// Iff -C1 s<= C2 s<= C0-C1 +// Also ULT predicate can also be UGT iff C0 != -1 (+invert result) +// SLT predicate can also be SGT iff C2 != INT_MAX (+invert res.) +static Value *canonicalizeClampLike(SelectInst &Sel0, ICmpInst &Cmp0, + InstCombiner::BuilderTy &Builder) { + Value *X = Sel0.getTrueValue(); + Value *Sel1 = Sel0.getFalseValue(); + + // First match the condition of the outermost select. + // Said condition must be one-use. 
+ if (!Cmp0.hasOneUse()) + return nullptr; + ICmpInst::Predicate Pred0 = Cmp0.getPredicate(); + Value *Cmp00 = Cmp0.getOperand(0); + Constant *C0; + if (!match(Cmp0.getOperand(1), + m_CombineAnd(m_AnyIntegralConstant(), m_Constant(C0)))) + return nullptr; + + if (!isa<SelectInst>(Sel1)) { + Pred0 = ICmpInst::getInversePredicate(Pred0); + std::swap(X, Sel1); + } + + // Canonicalize Cmp0 into ult or uge. + // FIXME: we shouldn't care about lanes that are 'undef' in the end? + switch (Pred0) { + case ICmpInst::Predicate::ICMP_ULT: + case ICmpInst::Predicate::ICMP_UGE: + // Although icmp ult %x, 0 is an unusual thing to try and should generally + // have been simplified, it does not verify with undef inputs so ensure we + // are not in a strange state. + if (!match(C0, m_SpecificInt_ICMP( + ICmpInst::Predicate::ICMP_NE, + APInt::getZero(C0->getType()->getScalarSizeInBits())))) + return nullptr; + break; // Great! + case ICmpInst::Predicate::ICMP_ULE: + case ICmpInst::Predicate::ICMP_UGT: + // We want to canonicalize it to 'ult' or 'uge', so we'll need to increment + // C0, which again means it must not have any all-ones elements. + if (!match(C0, + m_SpecificInt_ICMP( + ICmpInst::Predicate::ICMP_NE, + APInt::getAllOnes(C0->getType()->getScalarSizeInBits())))) + return nullptr; // Can't do, have all-ones element[s]. + Pred0 = ICmpInst::getFlippedStrictnessPredicate(Pred0); + C0 = InstCombiner::AddOne(C0); + break; + default: + return nullptr; // Unknown predicate. + } + + // Now that we've canonicalized the ICmp, we know the X we expect; + // the select in other hand should be one-use. + if (!Sel1->hasOneUse()) + return nullptr; + + // If the types do not match, look through any truncs to the underlying + // instruction. + if (Cmp00->getType() != X->getType() && X->hasOneUse()) + match(X, m_TruncOrSelf(m_Value(X))); + + // We now can finish matching the condition of the outermost select: + // it should either be the X itself, or an addition of some constant to X. 
+ Constant *C1; + if (Cmp00 == X) + C1 = ConstantInt::getNullValue(X->getType()); + else if (!match(Cmp00, + m_Add(m_Specific(X), + m_CombineAnd(m_AnyIntegralConstant(), m_Constant(C1))))) + return nullptr; + + Value *Cmp1; + ICmpInst::Predicate Pred1; + Constant *C2; + Value *ReplacementLow, *ReplacementHigh; + if (!match(Sel1, m_Select(m_Value(Cmp1), m_Value(ReplacementLow), + m_Value(ReplacementHigh))) || + !match(Cmp1, + m_ICmp(Pred1, m_Specific(X), + m_CombineAnd(m_AnyIntegralConstant(), m_Constant(C2))))) + return nullptr; + + if (!Cmp1->hasOneUse() && (Cmp00 == X || !Cmp00->hasOneUse())) + return nullptr; // Not enough one-use instructions for the fold. + // FIXME: this restriction could be relaxed if Cmp1 can be reused as one of + // two comparisons we'll need to build. + + // Canonicalize Cmp1 into the form we expect. + // FIXME: we shouldn't care about lanes that are 'undef' in the end? + switch (Pred1) { + case ICmpInst::Predicate::ICMP_SLT: + break; + case ICmpInst::Predicate::ICMP_SLE: + // We'd have to increment C2 by one, and for that it must not have signed + // max element, but then it would have been canonicalized to 'slt' before + // we get here. So we can't do anything useful with 'sle'. + return nullptr; + case ICmpInst::Predicate::ICMP_SGT: + // We want to canonicalize it to 'slt', so we'll need to increment C2, + // which again means it must not have any signed max elements. + if (!match(C2, + m_SpecificInt_ICMP(ICmpInst::Predicate::ICMP_NE, + APInt::getSignedMaxValue( + C2->getType()->getScalarSizeInBits())))) + return nullptr; // Can't do, have signed max element[s]. + C2 = InstCombiner::AddOne(C2); + LLVM_FALLTHROUGH; + case ICmpInst::Predicate::ICMP_SGE: + // Also non-canonical, but here we don't need to change C2, + // so we don't have any restrictions on C2, so we can just handle it. + Pred1 = ICmpInst::Predicate::ICMP_SLT; + std::swap(ReplacementLow, ReplacementHigh); + break; + default: + return nullptr; // Unknown predicate. 
+ } + assert(Pred1 == ICmpInst::Predicate::ICMP_SLT && + "Unexpected predicate type."); + + // The thresholds of this clamp-like pattern. + auto *ThresholdLowIncl = ConstantExpr::getNeg(C1); + auto *ThresholdHighExcl = ConstantExpr::getSub(C0, C1); + + assert((Pred0 == ICmpInst::Predicate::ICMP_ULT || + Pred0 == ICmpInst::Predicate::ICMP_UGE) && + "Unexpected predicate type."); + if (Pred0 == ICmpInst::Predicate::ICMP_UGE) + std::swap(ThresholdLowIncl, ThresholdHighExcl); + + // The fold has a precondition 1: C2 s>= ThresholdLow + auto *Precond1 = ConstantExpr::getICmp(ICmpInst::Predicate::ICMP_SGE, C2, + ThresholdLowIncl); + if (!match(Precond1, m_One())) + return nullptr; + // The fold has a precondition 2: C2 s<= ThresholdHigh + auto *Precond2 = ConstantExpr::getICmp(ICmpInst::Predicate::ICMP_SLE, C2, + ThresholdHighExcl); + if (!match(Precond2, m_One())) + return nullptr; + + // If we are matching from a truncated input, we need to sext the + // ReplacementLow and ReplacementHigh values. Only do the transform if they + // are free to extend due to being constants. + if (X->getType() != Sel0.getType()) { + Constant *LowC, *HighC; + if (!match(ReplacementLow, m_ImmConstant(LowC)) || + !match(ReplacementHigh, m_ImmConstant(HighC))) + return nullptr; + ReplacementLow = ConstantExpr::getSExt(LowC, X->getType()); + ReplacementHigh = ConstantExpr::getSExt(HighC, X->getType()); + } + + // All good, finally emit the new pattern. + Value *ShouldReplaceLow = Builder.CreateICmpSLT(X, ThresholdLowIncl); + Value *ShouldReplaceHigh = Builder.CreateICmpSGE(X, ThresholdHighExcl); + Value *MaybeReplacedLow = + Builder.CreateSelect(ShouldReplaceLow, ReplacementLow, X); + + // Create the final select. If we looked through a truncate above, we will + // need to retruncate the result. 
+ Value *MaybeReplacedHigh = Builder.CreateSelect( + ShouldReplaceHigh, ReplacementHigh, MaybeReplacedLow); + return Builder.CreateTrunc(MaybeReplacedHigh, Sel0.getType()); +} + +// If we have +// %cmp = icmp [canonical predicate] i32 %x, C0 +// %r = select i1 %cmp, i32 %y, i32 C1 +// Where C0 != C1 and %x may be different from %y, see if the constant that we +// will have if we flip the strictness of the predicate (i.e. without changing +// the result) is identical to the C1 in select. If it matches we can change +// original comparison to one with swapped predicate, reuse the constant, +// and swap the hands of select. +static Instruction * +tryToReuseConstantFromSelectInComparison(SelectInst &Sel, ICmpInst &Cmp, + InstCombinerImpl &IC) { + ICmpInst::Predicate Pred; + Value *X; + Constant *C0; + if (!match(&Cmp, m_OneUse(m_ICmp( + Pred, m_Value(X), + m_CombineAnd(m_AnyIntegralConstant(), m_Constant(C0)))))) + return nullptr; + + // If comparison predicate is non-relational, we won't be able to do anything. + if (ICmpInst::isEquality(Pred)) + return nullptr; + + // If comparison predicate is non-canonical, then we certainly won't be able + // to make it canonical; canonicalizeCmpWithConstant() already tried. + if (!InstCombiner::isCanonicalPredicate(Pred)) + return nullptr; + + // If the [input] type of comparison and select type are different, lets abort + // for now. We could try to compare constants with trunc/[zs]ext though. + if (C0->getType() != Sel.getType()) + return nullptr; + + // ULT with 'add' of a constant is canonical. See foldICmpAddConstant(). + // FIXME: Are there more magic icmp predicate+constant pairs we must avoid? + // Or should we just abandon this transform entirely? + if (Pred == CmpInst::ICMP_ULT && match(X, m_Add(m_Value(), m_Constant()))) + return nullptr; + + + Value *SelVal0, *SelVal1; // We do not care which one is from where. 
+ match(&Sel, m_Select(m_Value(), m_Value(SelVal0), m_Value(SelVal1))); + // At least one of these values we are selecting between must be a constant + // else we'll never succeed. + if (!match(SelVal0, m_AnyIntegralConstant()) && + !match(SelVal1, m_AnyIntegralConstant())) + return nullptr; + + // Does this constant C match any of the `select` values? + auto MatchesSelectValue = [SelVal0, SelVal1](Constant *C) { + return C->isElementWiseEqual(SelVal0) || C->isElementWiseEqual(SelVal1); + }; + + // If C0 *already* matches true/false value of select, we are done. + if (MatchesSelectValue(C0)) + return nullptr; + + // Check the constant we'd have with flipped-strictness predicate. + auto FlippedStrictness = + InstCombiner::getFlippedStrictnessPredicateAndConstant(Pred, C0); + if (!FlippedStrictness) + return nullptr; + + // If said constant doesn't match either, then there is no hope, + if (!MatchesSelectValue(FlippedStrictness->second)) + return nullptr; + + // It matched! Lets insert the new comparison just before select. + InstCombiner::BuilderTy::InsertPointGuard Guard(IC.Builder); + IC.Builder.SetInsertPoint(&Sel); + + Pred = ICmpInst::getSwappedPredicate(Pred); // Yes, swapped. + Value *NewCmp = IC.Builder.CreateICmp(Pred, X, FlippedStrictness->second, + Cmp.getName() + ".inv"); + IC.replaceOperand(Sel, 0, NewCmp); + Sel.swapValues(); + Sel.swapProfMetadata(); + + return &Sel; +} + +static Instruction *foldSelectZeroOrOnes(ICmpInst *Cmp, Value *TVal, + Value *FVal, + InstCombiner::BuilderTy &Builder) { + if (!Cmp->hasOneUse()) + return nullptr; + + const APInt *CmpC; + if (!match(Cmp->getOperand(1), m_APIntAllowUndef(CmpC))) + return nullptr; + + // (X u< 2) ? -X : -1 --> sext (X != 0) + Value *X = Cmp->getOperand(0); + if (Cmp->getPredicate() == ICmpInst::ICMP_ULT && *CmpC == 2 && + match(TVal, m_Neg(m_Specific(X))) && match(FVal, m_AllOnes())) + return new SExtInst(Builder.CreateIsNotNull(X), TVal->getType()); + + // (X u> 1) ? 
-1 : -X --> sext (X != 0) + if (Cmp->getPredicate() == ICmpInst::ICMP_UGT && *CmpC == 1 && + match(FVal, m_Neg(m_Specific(X))) && match(TVal, m_AllOnes())) + return new SExtInst(Builder.CreateIsNotNull(X), TVal->getType()); + + return nullptr; +} + +static Value *foldSelectInstWithICmpConst(SelectInst &SI, ICmpInst *ICI) { + const APInt *CmpC; + Value *V; + CmpInst::Predicate Pred; + if (!match(ICI, m_ICmp(Pred, m_Value(V), m_APInt(CmpC)))) + return nullptr; + + BinaryOperator *BO; + const APInt *C; + CmpInst::Predicate CPred; + if (match(&SI, m_Select(m_Specific(ICI), m_APInt(C), m_BinOp(BO)))) + CPred = ICI->getPredicate(); + else if (match(&SI, m_Select(m_Specific(ICI), m_BinOp(BO), m_APInt(C)))) + CPred = ICI->getInversePredicate(); + else + return nullptr; + + const APInt *BinOpC; + if (!match(BO, m_BinOp(m_Specific(V), m_APInt(BinOpC)))) + return nullptr; + + ConstantRange R = ConstantRange::makeExactICmpRegion(CPred, *CmpC) + .binaryOp(BO->getOpcode(), *BinOpC); + if (R == *C) { + BO->dropPoisonGeneratingFlags(); + return BO; + } + return nullptr; +} + +/// Visit a SelectInst that has an ICmpInst as its first operand. 
+Instruction *InstCombinerImpl::foldSelectInstWithICmp(SelectInst &SI, + ICmpInst *ICI) { + if (Instruction *NewSel = foldSelectValueEquivalence(SI, *ICI)) + return NewSel; + + if (Instruction *NewSPF = canonicalizeSPF(SI, *ICI, *this)) + return NewSPF; + + if (Value *V = foldSelectInstWithICmpConst(SI, ICI)) + return replaceInstUsesWith(SI, V); + + if (Value *V = canonicalizeClampLike(SI, *ICI, Builder)) + return replaceInstUsesWith(SI, V); + + if (Instruction *NewSel = + tryToReuseConstantFromSelectInComparison(SI, *ICI, *this)) + return NewSel; + + bool Changed = adjustMinMax(SI, *ICI); + + if (Value *V = foldSelectICmpAnd(SI, ICI, Builder)) + return replaceInstUsesWith(SI, V); + + // NOTE: if we wanted to, this is where to detect integer MIN/MAX + Value *TrueVal = SI.getTrueValue(); + Value *FalseVal = SI.getFalseValue(); + ICmpInst::Predicate Pred = ICI->getPredicate(); + Value *CmpLHS = ICI->getOperand(0); + Value *CmpRHS = ICI->getOperand(1); + if (CmpRHS != CmpLHS && isa<Constant>(CmpRHS)) { + if (CmpLHS == TrueVal && Pred == ICmpInst::ICMP_EQ) { + // Transform (X == C) ? X : Y -> (X == C) ? C : Y + SI.setOperand(1, CmpRHS); + Changed = true; + } else if (CmpLHS == FalseVal && Pred == ICmpInst::ICMP_NE) { + // Transform (X != C) ? Y : X -> (X != C) ? Y : C + SI.setOperand(2, CmpRHS); + Changed = true; + } + } + + // Canonicalize a signbit condition to use zero constant by swapping: + // (CmpLHS > -1) ? TV : FV --> (CmpLHS < 0) ? FV : TV + // To avoid conflicts (infinite loops) with other canonicalizations, this is + // not applied with any constant select arm. 
+ if (Pred == ICmpInst::ICMP_SGT && match(CmpRHS, m_AllOnes()) && + !match(TrueVal, m_Constant()) && !match(FalseVal, m_Constant()) && + ICI->hasOneUse()) { + InstCombiner::BuilderTy::InsertPointGuard Guard(Builder); + Builder.SetInsertPoint(&SI); + Value *IsNeg = Builder.CreateIsNeg(CmpLHS, ICI->getName()); + replaceOperand(SI, 0, IsNeg); + SI.swapValues(); + SI.swapProfMetadata(); + return &SI; + } + + // FIXME: This code is nearly duplicated in InstSimplify. Using/refactoring + // decomposeBitTestICmp() might help. + { + unsigned BitWidth = + DL.getTypeSizeInBits(TrueVal->getType()->getScalarType()); + APInt MinSignedValue = APInt::getSignedMinValue(BitWidth); + Value *X; + const APInt *Y, *C; + bool TrueWhenUnset; + bool IsBitTest = false; + if (ICmpInst::isEquality(Pred) && + match(CmpLHS, m_And(m_Value(X), m_Power2(Y))) && + match(CmpRHS, m_Zero())) { + IsBitTest = true; + TrueWhenUnset = Pred == ICmpInst::ICMP_EQ; + } else if (Pred == ICmpInst::ICMP_SLT && match(CmpRHS, m_Zero())) { + X = CmpLHS; + Y = &MinSignedValue; + IsBitTest = true; + TrueWhenUnset = false; + } else if (Pred == ICmpInst::ICMP_SGT && match(CmpRHS, m_AllOnes())) { + X = CmpLHS; + Y = &MinSignedValue; + IsBitTest = true; + TrueWhenUnset = true; + } + if (IsBitTest) { + Value *V = nullptr; + // (X & Y) == 0 ? X : X ^ Y --> X & ~Y + if (TrueWhenUnset && TrueVal == X && + match(FalseVal, m_Xor(m_Specific(X), m_APInt(C))) && *Y == *C) + V = Builder.CreateAnd(X, ~(*Y)); + // (X & Y) != 0 ? X ^ Y : X --> X & ~Y + else if (!TrueWhenUnset && FalseVal == X && + match(TrueVal, m_Xor(m_Specific(X), m_APInt(C))) && *Y == *C) + V = Builder.CreateAnd(X, ~(*Y)); + // (X & Y) == 0 ? X ^ Y : X --> X | Y + else if (TrueWhenUnset && FalseVal == X && + match(TrueVal, m_Xor(m_Specific(X), m_APInt(C))) && *Y == *C) + V = Builder.CreateOr(X, *Y); + // (X & Y) != 0 ? 
X : X ^ Y --> X | Y + else if (!TrueWhenUnset && TrueVal == X && + match(FalseVal, m_Xor(m_Specific(X), m_APInt(C))) && *Y == *C) + V = Builder.CreateOr(X, *Y); + + if (V) + return replaceInstUsesWith(SI, V); + } + } + + if (Instruction *V = + foldSelectICmpAndAnd(SI.getType(), ICI, TrueVal, FalseVal, Builder)) + return V; + + if (Instruction *V = foldSelectCtlzToCttz(ICI, TrueVal, FalseVal, Builder)) + return V; + + if (Instruction *V = foldSelectZeroOrOnes(ICI, TrueVal, FalseVal, Builder)) + return V; + + if (Value *V = foldSelectICmpAndOr(ICI, TrueVal, FalseVal, Builder)) + return replaceInstUsesWith(SI, V); + + if (Value *V = foldSelectICmpLshrAshr(ICI, TrueVal, FalseVal, Builder)) + return replaceInstUsesWith(SI, V); + + if (Value *V = foldSelectCttzCtlz(ICI, TrueVal, FalseVal, Builder)) + return replaceInstUsesWith(SI, V); + + if (Value *V = canonicalizeSaturatedSubtract(ICI, TrueVal, FalseVal, Builder)) + return replaceInstUsesWith(SI, V); + + if (Value *V = canonicalizeSaturatedAdd(ICI, TrueVal, FalseVal, Builder)) + return replaceInstUsesWith(SI, V); + + return Changed ? &SI : nullptr; +} + +/// SI is a select whose condition is a PHI node (but the two may be in +/// different blocks). See if the true/false values (V) are live in all of the +/// predecessor blocks of the PHI. For example, cases like this can't be mapped: +/// +/// X = phi [ C1, BB1], [C2, BB2] +/// Y = add +/// Z = select X, Y, 0 +/// +/// because Y is not live in BB1/BB2. +static bool canSelectOperandBeMappingIntoPredBlock(const Value *V, + const SelectInst &SI) { + // If the value is a non-instruction value like a constant or argument, it + // can always be mapped. + const Instruction *I = dyn_cast<Instruction>(V); + if (!I) return true; + + // If V is a PHI node defined in the same block as the condition PHI, we can + // map the arguments. 
+ const PHINode *CondPHI = cast<PHINode>(SI.getCondition()); + + if (const PHINode *VP = dyn_cast<PHINode>(I)) + if (VP->getParent() == CondPHI->getParent()) + return true; + + // Otherwise, if the PHI and select are defined in the same block and if V is + // defined in a different block, then we can transform it. + if (SI.getParent() == CondPHI->getParent() && + I->getParent() != CondPHI->getParent()) + return true; + + // Otherwise we have a 'hard' case and we can't tell without doing more + // detailed dominator based analysis, punt. + return false; +} + +/// We have an SPF (e.g. a min or max) of an SPF of the form: +/// SPF2(SPF1(A, B), C) +Instruction *InstCombinerImpl::foldSPFofSPF(Instruction *Inner, + SelectPatternFlavor SPF1, Value *A, + Value *B, Instruction &Outer, + SelectPatternFlavor SPF2, + Value *C) { + if (Outer.getType() != Inner->getType()) + return nullptr; + + if (C == A || C == B) { + // MAX(MAX(A, B), B) -> MAX(A, B) + // MIN(MIN(a, b), a) -> MIN(a, b) + // TODO: This could be done in instsimplify. + if (SPF1 == SPF2 && SelectPatternResult::isMinOrMax(SPF1)) + return replaceInstUsesWith(Outer, Inner); + } + + return nullptr; +} + +/// Turn select C, (X + Y), (X - Y) --> (X + (select C, Y, (-Y))). +/// This is even legal for FP. 
+static Instruction *foldAddSubSelect(SelectInst &SI, + InstCombiner::BuilderTy &Builder) { + Value *CondVal = SI.getCondition(); + Value *TrueVal = SI.getTrueValue(); + Value *FalseVal = SI.getFalseValue(); + auto *TI = dyn_cast<Instruction>(TrueVal); + auto *FI = dyn_cast<Instruction>(FalseVal); + if (!TI || !FI || !TI->hasOneUse() || !FI->hasOneUse()) + return nullptr; + + Instruction *AddOp = nullptr, *SubOp = nullptr; + if ((TI->getOpcode() == Instruction::Sub && + FI->getOpcode() == Instruction::Add) || + (TI->getOpcode() == Instruction::FSub && + FI->getOpcode() == Instruction::FAdd)) { + AddOp = FI; + SubOp = TI; + } else if ((FI->getOpcode() == Instruction::Sub && + TI->getOpcode() == Instruction::Add) || + (FI->getOpcode() == Instruction::FSub && + TI->getOpcode() == Instruction::FAdd)) { + AddOp = TI; + SubOp = FI; + } + + if (AddOp) { + Value *OtherAddOp = nullptr; + if (SubOp->getOperand(0) == AddOp->getOperand(0)) { + OtherAddOp = AddOp->getOperand(1); + } else if (SubOp->getOperand(0) == AddOp->getOperand(1)) { + OtherAddOp = AddOp->getOperand(0); + } + + if (OtherAddOp) { + // So at this point we know we have (Y -> OtherAddOp): + // select C, (add X, Y), (sub X, Z) + Value *NegVal; // Compute -Z + if (SI.getType()->isFPOrFPVectorTy()) { + NegVal = Builder.CreateFNeg(SubOp->getOperand(1)); + if (Instruction *NegInst = dyn_cast<Instruction>(NegVal)) { + FastMathFlags Flags = AddOp->getFastMathFlags(); + Flags &= SubOp->getFastMathFlags(); + NegInst->setFastMathFlags(Flags); + } + } else { + NegVal = Builder.CreateNeg(SubOp->getOperand(1)); + } + + Value *NewTrueOp = OtherAddOp; + Value *NewFalseOp = NegVal; + if (AddOp != TI) + std::swap(NewTrueOp, NewFalseOp); + Value *NewSel = Builder.CreateSelect(CondVal, NewTrueOp, NewFalseOp, + SI.getName() + ".p", &SI); + + if (SI.getType()->isFPOrFPVectorTy()) { + Instruction *RI = + BinaryOperator::CreateFAdd(SubOp->getOperand(0), NewSel); + + FastMathFlags Flags = AddOp->getFastMathFlags(); + Flags &= 
SubOp->getFastMathFlags(); + RI->setFastMathFlags(Flags); + return RI; + } else + return BinaryOperator::CreateAdd(SubOp->getOperand(0), NewSel); + } + } + return nullptr; +} + +/// Turn X + Y overflows ? -1 : X + Y -> uadd_sat X, Y +/// And X - Y overflows ? 0 : X - Y -> usub_sat X, Y +/// Along with a number of patterns similar to: +/// X + Y overflows ? (X < 0 ? INTMIN : INTMAX) : X + Y --> sadd_sat X, Y +/// X - Y overflows ? (X > 0 ? INTMAX : INTMIN) : X - Y --> ssub_sat X, Y +static Instruction * +foldOverflowingAddSubSelect(SelectInst &SI, InstCombiner::BuilderTy &Builder) { + Value *CondVal = SI.getCondition(); + Value *TrueVal = SI.getTrueValue(); + Value *FalseVal = SI.getFalseValue(); + + WithOverflowInst *II; + if (!match(CondVal, m_ExtractValue<1>(m_WithOverflowInst(II))) || + !match(FalseVal, m_ExtractValue<0>(m_Specific(II)))) + return nullptr; + + Value *X = II->getLHS(); + Value *Y = II->getRHS(); + + auto IsSignedSaturateLimit = [&](Value *Limit, bool IsAdd) { + Type *Ty = Limit->getType(); + + ICmpInst::Predicate Pred; + Value *TrueVal, *FalseVal, *Op; + const APInt *C; + if (!match(Limit, m_Select(m_ICmp(Pred, m_Value(Op), m_APInt(C)), + m_Value(TrueVal), m_Value(FalseVal)))) + return false; + + auto IsZeroOrOne = [](const APInt &C) { return C.isZero() || C.isOne(); }; + auto IsMinMax = [&](Value *Min, Value *Max) { + APInt MinVal = APInt::getSignedMinValue(Ty->getScalarSizeInBits()); + APInt MaxVal = APInt::getSignedMaxValue(Ty->getScalarSizeInBits()); + return match(Min, m_SpecificInt(MinVal)) && + match(Max, m_SpecificInt(MaxVal)); + }; + + if (Op != X && Op != Y) + return false; + + if (IsAdd) { + // X + Y overflows ? (X <s 0 ? INTMIN : INTMAX) : X + Y --> sadd_sat X, Y + // X + Y overflows ? (X <s 1 ? INTMIN : INTMAX) : X + Y --> sadd_sat X, Y + // X + Y overflows ? (Y <s 0 ? INTMIN : INTMAX) : X + Y --> sadd_sat X, Y + // X + Y overflows ? (Y <s 1 ? 
INTMIN : INTMAX) : X + Y --> sadd_sat X, Y + if (Pred == ICmpInst::ICMP_SLT && IsZeroOrOne(*C) && + IsMinMax(TrueVal, FalseVal)) + return true; + // X + Y overflows ? (X >s 0 ? INTMAX : INTMIN) : X + Y --> sadd_sat X, Y + // X + Y overflows ? (X >s -1 ? INTMAX : INTMIN) : X + Y --> sadd_sat X, Y + // X + Y overflows ? (Y >s 0 ? INTMAX : INTMIN) : X + Y --> sadd_sat X, Y + // X + Y overflows ? (Y >s -1 ? INTMAX : INTMIN) : X + Y --> sadd_sat X, Y + if (Pred == ICmpInst::ICMP_SGT && IsZeroOrOne(*C + 1) && + IsMinMax(FalseVal, TrueVal)) + return true; + } else { + // X - Y overflows ? (X <s 0 ? INTMIN : INTMAX) : X - Y --> ssub_sat X, Y + // X - Y overflows ? (X <s -1 ? INTMIN : INTMAX) : X - Y --> ssub_sat X, Y + if (Op == X && Pred == ICmpInst::ICMP_SLT && IsZeroOrOne(*C + 1) && + IsMinMax(TrueVal, FalseVal)) + return true; + // X - Y overflows ? (X >s -1 ? INTMAX : INTMIN) : X - Y --> ssub_sat X, Y + // X - Y overflows ? (X >s -2 ? INTMAX : INTMIN) : X - Y --> ssub_sat X, Y + if (Op == X && Pred == ICmpInst::ICMP_SGT && IsZeroOrOne(*C + 2) && + IsMinMax(FalseVal, TrueVal)) + return true; + // X - Y overflows ? (Y <s 0 ? INTMAX : INTMIN) : X - Y --> ssub_sat X, Y + // X - Y overflows ? (Y <s 1 ? INTMAX : INTMIN) : X - Y --> ssub_sat X, Y + if (Op == Y && Pred == ICmpInst::ICMP_SLT && IsZeroOrOne(*C) && + IsMinMax(FalseVal, TrueVal)) + return true; + // X - Y overflows ? (Y >s 0 ? INTMIN : INTMAX) : X - Y --> ssub_sat X, Y + // X - Y overflows ? (Y >s -1 ? INTMIN : INTMAX) : X - Y --> ssub_sat X, Y + if (Op == Y && Pred == ICmpInst::ICMP_SGT && IsZeroOrOne(*C + 1) && + IsMinMax(TrueVal, FalseVal)) + return true; + } + + return false; + }; + + Intrinsic::ID NewIntrinsicID; + if (II->getIntrinsicID() == Intrinsic::uadd_with_overflow && + match(TrueVal, m_AllOnes())) + // X + Y overflows ? 
-1 : X + Y -> uadd_sat X, Y + NewIntrinsicID = Intrinsic::uadd_sat; + else if (II->getIntrinsicID() == Intrinsic::usub_with_overflow && + match(TrueVal, m_Zero())) + // X - Y overflows ? 0 : X - Y -> usub_sat X, Y + NewIntrinsicID = Intrinsic::usub_sat; + else if (II->getIntrinsicID() == Intrinsic::sadd_with_overflow && + IsSignedSaturateLimit(TrueVal, /*IsAdd=*/true)) + // X + Y overflows ? (X <s 0 ? INTMIN : INTMAX) : X + Y --> sadd_sat X, Y + // X + Y overflows ? (X <s 1 ? INTMIN : INTMAX) : X + Y --> sadd_sat X, Y + // X + Y overflows ? (X >s 0 ? INTMAX : INTMIN) : X + Y --> sadd_sat X, Y + // X + Y overflows ? (X >s -1 ? INTMAX : INTMIN) : X + Y --> sadd_sat X, Y + // X + Y overflows ? (Y <s 0 ? INTMIN : INTMAX) : X + Y --> sadd_sat X, Y + // X + Y overflows ? (Y <s 1 ? INTMIN : INTMAX) : X + Y --> sadd_sat X, Y + // X + Y overflows ? (Y >s 0 ? INTMAX : INTMIN) : X + Y --> sadd_sat X, Y + // X + Y overflows ? (Y >s -1 ? INTMAX : INTMIN) : X + Y --> sadd_sat X, Y + NewIntrinsicID = Intrinsic::sadd_sat; + else if (II->getIntrinsicID() == Intrinsic::ssub_with_overflow && + IsSignedSaturateLimit(TrueVal, /*IsAdd=*/false)) + // X - Y overflows ? (X <s 0 ? INTMIN : INTMAX) : X - Y --> ssub_sat X, Y + // X - Y overflows ? (X <s -1 ? INTMIN : INTMAX) : X - Y --> ssub_sat X, Y + // X - Y overflows ? (X >s -1 ? INTMAX : INTMIN) : X - Y --> ssub_sat X, Y + // X - Y overflows ? (X >s -2 ? INTMAX : INTMIN) : X - Y --> ssub_sat X, Y + // X - Y overflows ? (Y <s 0 ? INTMAX : INTMIN) : X - Y --> ssub_sat X, Y + // X - Y overflows ? (Y <s 1 ? INTMAX : INTMIN) : X - Y --> ssub_sat X, Y + // X - Y overflows ? (Y >s 0 ? INTMIN : INTMAX) : X - Y --> ssub_sat X, Y + // X - Y overflows ? (Y >s -1 ? 
INTMIN : INTMAX) : X - Y --> ssub_sat X, Y + NewIntrinsicID = Intrinsic::ssub_sat; + else + return nullptr; + + Function *F = + Intrinsic::getDeclaration(SI.getModule(), NewIntrinsicID, SI.getType()); + return CallInst::Create(F, {X, Y}); +} + +Instruction *InstCombinerImpl::foldSelectExtConst(SelectInst &Sel) { + Constant *C; + if (!match(Sel.getTrueValue(), m_Constant(C)) && + !match(Sel.getFalseValue(), m_Constant(C))) + return nullptr; + + Instruction *ExtInst; + if (!match(Sel.getTrueValue(), m_Instruction(ExtInst)) && + !match(Sel.getFalseValue(), m_Instruction(ExtInst))) + return nullptr; + + auto ExtOpcode = ExtInst->getOpcode(); + if (ExtOpcode != Instruction::ZExt && ExtOpcode != Instruction::SExt) + return nullptr; + + // If we are extending from a boolean type or if we can create a select that + // has the same size operands as its condition, try to narrow the select. + Value *X = ExtInst->getOperand(0); + Type *SmallType = X->getType(); + Value *Cond = Sel.getCondition(); + auto *Cmp = dyn_cast<CmpInst>(Cond); + if (!SmallType->isIntOrIntVectorTy(1) && + (!Cmp || Cmp->getOperand(0)->getType() != SmallType)) + return nullptr; + + // If the constant is the same after truncation to the smaller type and + // extension to the original type, we can narrow the select. 
+ Type *SelType = Sel.getType(); + Constant *TruncC = ConstantExpr::getTrunc(C, SmallType); + Constant *ExtC = ConstantExpr::getCast(ExtOpcode, TruncC, SelType); + if (ExtC == C && ExtInst->hasOneUse()) { + Value *TruncCVal = cast<Value>(TruncC); + if (ExtInst == Sel.getFalseValue()) + std::swap(X, TruncCVal); + + // select Cond, (ext X), C --> ext(select Cond, X, C') + // select Cond, C, (ext X) --> ext(select Cond, C', X) + Value *NewSel = Builder.CreateSelect(Cond, X, TruncCVal, "narrow", &Sel); + return CastInst::Create(Instruction::CastOps(ExtOpcode), NewSel, SelType); + } + + // If one arm of the select is the extend of the condition, replace that arm + // with the extension of the appropriate known bool value. + if (Cond == X) { + if (ExtInst == Sel.getTrueValue()) { + // select X, (sext X), C --> select X, -1, C + // select X, (zext X), C --> select X, 1, C + Constant *One = ConstantInt::getTrue(SmallType); + Constant *AllOnesOrOne = ConstantExpr::getCast(ExtOpcode, One, SelType); + return SelectInst::Create(Cond, AllOnesOrOne, C, "", nullptr, &Sel); + } else { + // select X, C, (sext X) --> select X, C, 0 + // select X, C, (zext X) --> select X, C, 0 + Constant *Zero = ConstantInt::getNullValue(SelType); + return SelectInst::Create(Cond, C, Zero, "", nullptr, &Sel); + } + } + + return nullptr; +} + +/// Try to transform a vector select with a constant condition vector into a +/// shuffle for easier combining with other shuffles and insert/extract. 
+static Instruction *canonicalizeSelectToShuffle(SelectInst &SI) { + Value *CondVal = SI.getCondition(); + Constant *CondC; + auto *CondValTy = dyn_cast<FixedVectorType>(CondVal->getType()); + if (!CondValTy || !match(CondVal, m_Constant(CondC))) + return nullptr; + + unsigned NumElts = CondValTy->getNumElements(); + SmallVector<int, 16> Mask; + Mask.reserve(NumElts); + for (unsigned i = 0; i != NumElts; ++i) { + Constant *Elt = CondC->getAggregateElement(i); + if (!Elt) + return nullptr; + + if (Elt->isOneValue()) { + // If the select condition element is true, choose from the 1st vector. + Mask.push_back(i); + } else if (Elt->isNullValue()) { + // If the select condition element is false, choose from the 2nd vector. + Mask.push_back(i + NumElts); + } else if (isa<UndefValue>(Elt)) { + // Undef in a select condition (choose one of the operands) does not mean + // the same thing as undef in a shuffle mask (any value is acceptable), so + // give up. + return nullptr; + } else { + // Bail out on a constant expression. + return nullptr; + } + } + + return new ShuffleVectorInst(SI.getTrueValue(), SI.getFalseValue(), Mask); +} + +/// If we have a select of vectors with a scalar condition, try to convert that +/// to a vector select by splatting the condition. A splat may get folded with +/// other operations in IR and having all operands of a select be vector types +/// is likely better for vector codegen. +static Instruction *canonicalizeScalarSelectOfVecs(SelectInst &Sel, + InstCombinerImpl &IC) { + auto *Ty = dyn_cast<VectorType>(Sel.getType()); + if (!Ty) + return nullptr; + + // We can replace a single-use extract with constant index. 
+ Value *Cond = Sel.getCondition(); + if (!match(Cond, m_OneUse(m_ExtractElt(m_Value(), m_ConstantInt())))) + return nullptr; + + // select (extelt V, Index), T, F --> select (splat V, Index), T, F + // Splatting the extracted condition reduces code (we could directly create a + // splat shuffle of the source vector to eliminate the intermediate step). + return IC.replaceOperand( + Sel, 0, IC.Builder.CreateVectorSplat(Ty->getElementCount(), Cond)); +} + +/// Reuse bitcasted operands between a compare and select: +/// select (cmp (bitcast C), (bitcast D)), (bitcast' C), (bitcast' D) --> +/// bitcast (select (cmp (bitcast C), (bitcast D)), (bitcast C), (bitcast D)) +static Instruction *foldSelectCmpBitcasts(SelectInst &Sel, + InstCombiner::BuilderTy &Builder) { + Value *Cond = Sel.getCondition(); + Value *TVal = Sel.getTrueValue(); + Value *FVal = Sel.getFalseValue(); + + CmpInst::Predicate Pred; + Value *A, *B; + if (!match(Cond, m_Cmp(Pred, m_Value(A), m_Value(B)))) + return nullptr; + + // The select condition is a compare instruction. If the select's true/false + // values are already the same as the compare operands, there's nothing to do. + if (TVal == A || TVal == B || FVal == A || FVal == B) + return nullptr; + + Value *C, *D; + if (!match(A, m_BitCast(m_Value(C))) || !match(B, m_BitCast(m_Value(D)))) + return nullptr; + + // select (cmp (bitcast C), (bitcast D)), (bitcast TSrc), (bitcast FSrc) + Value *TSrc, *FSrc; + if (!match(TVal, m_BitCast(m_Value(TSrc))) || + !match(FVal, m_BitCast(m_Value(FSrc)))) + return nullptr; + + // If the select true/false values are *different bitcasts* of the same source + // operands, make the select operands the same as the compare operands and + // cast the result. This is the canonical select form for min/max. 
+ Value *NewSel; + if (TSrc == C && FSrc == D) { + // select (cmp (bitcast C), (bitcast D)), (bitcast' C), (bitcast' D) --> + // bitcast (select (cmp A, B), A, B) + NewSel = Builder.CreateSelect(Cond, A, B, "", &Sel); + } else if (TSrc == D && FSrc == C) { + // select (cmp (bitcast C), (bitcast D)), (bitcast' D), (bitcast' C) --> + // bitcast (select (cmp A, B), B, A) + NewSel = Builder.CreateSelect(Cond, B, A, "", &Sel); + } else { + return nullptr; + } + return CastInst::CreateBitOrPointerCast(NewSel, Sel.getType()); +} + +/// Try to eliminate select instructions that test the returned flag of cmpxchg +/// instructions. +/// +/// If a select instruction tests the returned flag of a cmpxchg instruction and +/// selects between the returned value of the cmpxchg instruction its compare +/// operand, the result of the select will always be equal to its false value. +/// For example: +/// +/// %0 = cmpxchg i64* %ptr, i64 %compare, i64 %new_value seq_cst seq_cst +/// %1 = extractvalue { i64, i1 } %0, 1 +/// %2 = extractvalue { i64, i1 } %0, 0 +/// %3 = select i1 %1, i64 %compare, i64 %2 +/// ret i64 %3 +/// +/// The returned value of the cmpxchg instruction (%2) is the original value +/// located at %ptr prior to any update. If the cmpxchg operation succeeds, %2 +/// must have been equal to %compare. Thus, the result of the select is always +/// equal to %2, and the code can be simplified to: +/// +/// %0 = cmpxchg i64* %ptr, i64 %compare, i64 %new_value seq_cst seq_cst +/// %1 = extractvalue { i64, i1 } %0, 0 +/// ret i64 %1 +/// +static Value *foldSelectCmpXchg(SelectInst &SI) { + // A helper that determines if V is an extractvalue instruction whose + // aggregate operand is a cmpxchg instruction and whose single index is equal + // to I. If such conditions are true, the helper returns the cmpxchg + // instruction; otherwise, a nullptr is returned. 
+ auto isExtractFromCmpXchg = [](Value *V, unsigned I) -> AtomicCmpXchgInst * { + auto *Extract = dyn_cast<ExtractValueInst>(V); + if (!Extract) + return nullptr; + if (Extract->getIndices()[0] != I) + return nullptr; + return dyn_cast<AtomicCmpXchgInst>(Extract->getAggregateOperand()); + }; + + // If the select has a single user, and this user is a select instruction that + // we can simplify, skip the cmpxchg simplification for now. + if (SI.hasOneUse()) + if (auto *Select = dyn_cast<SelectInst>(SI.user_back())) + if (Select->getCondition() == SI.getCondition()) + if (Select->getFalseValue() == SI.getTrueValue() || + Select->getTrueValue() == SI.getFalseValue()) + return nullptr; + + // Ensure the select condition is the returned flag of a cmpxchg instruction. + auto *CmpXchg = isExtractFromCmpXchg(SI.getCondition(), 1); + if (!CmpXchg) + return nullptr; + + // Check the true value case: The true value of the select is the returned + // value of the same cmpxchg used by the condition, and the false value is the + // cmpxchg instruction's compare operand. + if (auto *X = isExtractFromCmpXchg(SI.getTrueValue(), 0)) + if (X == CmpXchg && X->getCompareOperand() == SI.getFalseValue()) + return SI.getFalseValue(); + + // Check the false value case: The false value of the select is the returned + // value of the same cmpxchg used by the condition, and the true value is the + // cmpxchg instruction's compare operand. + if (auto *X = isExtractFromCmpXchg(SI.getFalseValue(), 0)) + if (X == CmpXchg && X->getCompareOperand() == SI.getTrueValue()) + return SI.getFalseValue(); + + return nullptr; +} + +/// Try to reduce a funnel/rotate pattern that includes a compare and select +/// into a funnel shift intrinsic. Example: +/// rotl32(a, b) --> (b == 0 ? a : ((a >> (32 - b)) | (a << b))) +/// --> call llvm.fshl.i32(a, a, b) +/// fshl32(a, b, c) --> (c == 0 ? a : ((b >> (32 - c)) | (a << c))) +/// --> call llvm.fshl.i32(a, b, c) +/// fshr32(a, b, c) --> (c == 0 ? 
b : ((a >> (32 - c)) | (b << c))) +/// --> call llvm.fshr.i32(a, b, c) +static Instruction *foldSelectFunnelShift(SelectInst &Sel, + InstCombiner::BuilderTy &Builder) { + // This must be a power-of-2 type for a bitmasking transform to be valid. + unsigned Width = Sel.getType()->getScalarSizeInBits(); + if (!isPowerOf2_32(Width)) + return nullptr; + + BinaryOperator *Or0, *Or1; + if (!match(Sel.getFalseValue(), m_OneUse(m_Or(m_BinOp(Or0), m_BinOp(Or1))))) + return nullptr; + + Value *SV0, *SV1, *SA0, *SA1; + if (!match(Or0, m_OneUse(m_LogicalShift(m_Value(SV0), + m_ZExtOrSelf(m_Value(SA0))))) || + !match(Or1, m_OneUse(m_LogicalShift(m_Value(SV1), + m_ZExtOrSelf(m_Value(SA1))))) || + Or0->getOpcode() == Or1->getOpcode()) + return nullptr; + + // Canonicalize to or(shl(SV0, SA0), lshr(SV1, SA1)). + if (Or0->getOpcode() == BinaryOperator::LShr) { + std::swap(Or0, Or1); + std::swap(SV0, SV1); + std::swap(SA0, SA1); + } + assert(Or0->getOpcode() == BinaryOperator::Shl && + Or1->getOpcode() == BinaryOperator::LShr && + "Illegal or(shift,shift) pair"); + + // Check the shift amounts to see if they are an opposite pair. + Value *ShAmt; + if (match(SA1, m_OneUse(m_Sub(m_SpecificInt(Width), m_Specific(SA0))))) + ShAmt = SA0; + else if (match(SA0, m_OneUse(m_Sub(m_SpecificInt(Width), m_Specific(SA1))))) + ShAmt = SA1; + else + return nullptr; + + // We should now have this pattern: + // select ?, TVal, (or (shl SV0, SA0), (lshr SV1, SA1)) + // The false value of the select must be a funnel-shift of the true value: + // IsFShl -> TVal must be SV0 else TVal must be SV1. + bool IsFshl = (ShAmt == SA0); + Value *TVal = Sel.getTrueValue(); + if ((IsFshl && TVal != SV0) || (!IsFshl && TVal != SV1)) + return nullptr; + + // Finally, see if the select is filtering out a shift-by-zero. 
+ Value *Cond = Sel.getCondition(); + ICmpInst::Predicate Pred; + if (!match(Cond, m_OneUse(m_ICmp(Pred, m_Specific(ShAmt), m_ZeroInt()))) || + Pred != ICmpInst::ICMP_EQ) + return nullptr; + + // If this is not a rotate then the select was blocking poison from the + // 'shift-by-zero' non-TVal, but a funnel shift won't - so freeze it. + if (SV0 != SV1) { + if (IsFshl && !llvm::isGuaranteedNotToBePoison(SV1)) + SV1 = Builder.CreateFreeze(SV1); + else if (!IsFshl && !llvm::isGuaranteedNotToBePoison(SV0)) + SV0 = Builder.CreateFreeze(SV0); + } + + // This is a funnel/rotate that avoids shift-by-bitwidth UB in a suboptimal way. + // Convert to funnel shift intrinsic. + Intrinsic::ID IID = IsFshl ? Intrinsic::fshl : Intrinsic::fshr; + Function *F = Intrinsic::getDeclaration(Sel.getModule(), IID, Sel.getType()); + ShAmt = Builder.CreateZExt(ShAmt, Sel.getType()); + return CallInst::Create(F, { SV0, SV1, ShAmt }); +} + +static Instruction *foldSelectToCopysign(SelectInst &Sel, + InstCombiner::BuilderTy &Builder) { + Value *Cond = Sel.getCondition(); + Value *TVal = Sel.getTrueValue(); + Value *FVal = Sel.getFalseValue(); + Type *SelType = Sel.getType(); + + // Match select ?, TC, FC where the constants are equal but negated. + // TODO: Generalize to handle a negated variable operand? + const APFloat *TC, *FC; + if (!match(TVal, m_APFloatAllowUndef(TC)) || + !match(FVal, m_APFloatAllowUndef(FC)) || + !abs(*TC).bitwiseIsEqual(abs(*FC))) + return nullptr; + + assert(TC != FC && "Expected equal select arms to simplify"); + + Value *X; + const APInt *C; + bool IsTrueIfSignSet; + ICmpInst::Predicate Pred; + if (!match(Cond, m_OneUse(m_ICmp(Pred, m_BitCast(m_Value(X)), m_APInt(C)))) || + !InstCombiner::isSignBitCheck(Pred, *C, IsTrueIfSignSet) || + X->getType() != SelType) + return nullptr; + + // If needed, negate the value that will be the sign argument of the copysign: + // (bitcast X) < 0 ? -TC : TC --> copysign(TC, X) + // (bitcast X) < 0 ? 
TC : -TC --> copysign(TC, -X) + // (bitcast X) >= 0 ? -TC : TC --> copysign(TC, -X) + // (bitcast X) >= 0 ? TC : -TC --> copysign(TC, X) + // Note: FMF from the select can not be propagated to the new instructions. + if (IsTrueIfSignSet ^ TC->isNegative()) + X = Builder.CreateFNeg(X); + + // Canonicalize the magnitude argument as the positive constant since we do + // not care about its sign. + Value *MagArg = ConstantFP::get(SelType, abs(*TC)); + Function *F = Intrinsic::getDeclaration(Sel.getModule(), Intrinsic::copysign, + Sel.getType()); + return CallInst::Create(F, { MagArg, X }); +} + +Instruction *InstCombinerImpl::foldVectorSelect(SelectInst &Sel) { + auto *VecTy = dyn_cast<FixedVectorType>(Sel.getType()); + if (!VecTy) + return nullptr; + + unsigned NumElts = VecTy->getNumElements(); + APInt UndefElts(NumElts, 0); + APInt AllOnesEltMask(APInt::getAllOnes(NumElts)); + if (Value *V = SimplifyDemandedVectorElts(&Sel, AllOnesEltMask, UndefElts)) { + if (V != &Sel) + return replaceInstUsesWith(Sel, V); + return &Sel; + } + + // A select of a "select shuffle" with a common operand can be rearranged + // to select followed by "select shuffle". Because of poison, this only works + // in the case of a shuffle with no undefined mask elements. 
+ Value *Cond = Sel.getCondition(); + Value *TVal = Sel.getTrueValue(); + Value *FVal = Sel.getFalseValue(); + Value *X, *Y; + ArrayRef<int> Mask; + if (match(TVal, m_OneUse(m_Shuffle(m_Value(X), m_Value(Y), m_Mask(Mask)))) && + !is_contained(Mask, UndefMaskElem) && + cast<ShuffleVectorInst>(TVal)->isSelect()) { + if (X == FVal) { + // select Cond, (shuf_sel X, Y), X --> shuf_sel X, (select Cond, Y, X) + Value *NewSel = Builder.CreateSelect(Cond, Y, X, "sel", &Sel); + return new ShuffleVectorInst(X, NewSel, Mask); + } + if (Y == FVal) { + // select Cond, (shuf_sel X, Y), Y --> shuf_sel (select Cond, X, Y), Y + Value *NewSel = Builder.CreateSelect(Cond, X, Y, "sel", &Sel); + return new ShuffleVectorInst(NewSel, Y, Mask); + } + } + if (match(FVal, m_OneUse(m_Shuffle(m_Value(X), m_Value(Y), m_Mask(Mask)))) && + !is_contained(Mask, UndefMaskElem) && + cast<ShuffleVectorInst>(FVal)->isSelect()) { + if (X == TVal) { + // select Cond, X, (shuf_sel X, Y) --> shuf_sel X, (select Cond, X, Y) + Value *NewSel = Builder.CreateSelect(Cond, X, Y, "sel", &Sel); + return new ShuffleVectorInst(X, NewSel, Mask); + } + if (Y == TVal) { + // select Cond, Y, (shuf_sel X, Y) --> shuf_sel (select Cond, Y, X), Y + Value *NewSel = Builder.CreateSelect(Cond, Y, X, "sel", &Sel); + return new ShuffleVectorInst(NewSel, Y, Mask); + } + } + + return nullptr; +} + +static Instruction *foldSelectToPhiImpl(SelectInst &Sel, BasicBlock *BB, + const DominatorTree &DT, + InstCombiner::BuilderTy &Builder) { + // Find the block's immediate dominator that ends with a conditional branch + // that matches select's condition (maybe inverted). 
+ auto *IDomNode = DT[BB]->getIDom(); + if (!IDomNode) + return nullptr; + BasicBlock *IDom = IDomNode->getBlock(); + + Value *Cond = Sel.getCondition(); + Value *IfTrue, *IfFalse; + BasicBlock *TrueSucc, *FalseSucc; + if (match(IDom->getTerminator(), + m_Br(m_Specific(Cond), m_BasicBlock(TrueSucc), + m_BasicBlock(FalseSucc)))) { + IfTrue = Sel.getTrueValue(); + IfFalse = Sel.getFalseValue(); + } else if (match(IDom->getTerminator(), + m_Br(m_Not(m_Specific(Cond)), m_BasicBlock(TrueSucc), + m_BasicBlock(FalseSucc)))) { + IfTrue = Sel.getFalseValue(); + IfFalse = Sel.getTrueValue(); + } else + return nullptr; + + // Make sure the branches are actually different. + if (TrueSucc == FalseSucc) + return nullptr; + + // We want to replace select %cond, %a, %b with a phi that takes value %a + // for all incoming edges that are dominated by condition `%cond == true`, + // and value %b for edges dominated by condition `%cond == false`. If %a + // or %b are also phis from the same basic block, we can go further and take + // their incoming values from the corresponding blocks. + BasicBlockEdge TrueEdge(IDom, TrueSucc); + BasicBlockEdge FalseEdge(IDom, FalseSucc); + DenseMap<BasicBlock *, Value *> Inputs; + for (auto *Pred : predecessors(BB)) { + // Check implication. + BasicBlockEdge Incoming(Pred, BB); + if (DT.dominates(TrueEdge, Incoming)) + Inputs[Pred] = IfTrue->DoPHITranslation(BB, Pred); + else if (DT.dominates(FalseEdge, Incoming)) + Inputs[Pred] = IfFalse->DoPHITranslation(BB, Pred); + else + return nullptr; + // Check availability. 
+ if (auto *Insn = dyn_cast<Instruction>(Inputs[Pred])) + if (!DT.dominates(Insn, Pred->getTerminator())) + return nullptr; + } + + Builder.SetInsertPoint(&*BB->begin()); + auto *PN = Builder.CreatePHI(Sel.getType(), Inputs.size()); + for (auto *Pred : predecessors(BB)) + PN->addIncoming(Inputs[Pred], Pred); + PN->takeName(&Sel); + return PN; +} + +static Instruction *foldSelectToPhi(SelectInst &Sel, const DominatorTree &DT, + InstCombiner::BuilderTy &Builder) { + // Try to replace this select with Phi in one of these blocks. + SmallSetVector<BasicBlock *, 4> CandidateBlocks; + CandidateBlocks.insert(Sel.getParent()); + for (Value *V : Sel.operands()) + if (auto *I = dyn_cast<Instruction>(V)) + CandidateBlocks.insert(I->getParent()); + + for (BasicBlock *BB : CandidateBlocks) + if (auto *PN = foldSelectToPhiImpl(Sel, BB, DT, Builder)) + return PN; + return nullptr; +} + +static Value *foldSelectWithFrozenICmp(SelectInst &Sel, InstCombiner::BuilderTy &Builder) { + FreezeInst *FI = dyn_cast<FreezeInst>(Sel.getCondition()); + if (!FI) + return nullptr; + + Value *Cond = FI->getOperand(0); + Value *TrueVal = Sel.getTrueValue(), *FalseVal = Sel.getFalseValue(); + + // select (freeze(x == y)), x, y --> y + // select (freeze(x != y)), x, y --> x + // The freeze should be only used by this select. Otherwise, remaining uses of + // the freeze can observe a contradictory value. + // c = freeze(x == y) ; Let's assume that y = poison & x = 42; c is 0 or 1 + // a = select c, x, y ; + // f(a, c) ; f(poison, 1) cannot happen, but if a is folded + // ; to y, this can happen. + CmpInst::Predicate Pred; + if (FI->hasOneUse() && + match(Cond, m_c_ICmp(Pred, m_Specific(TrueVal), m_Specific(FalseVal))) && + (Pred == ICmpInst::ICMP_EQ || Pred == ICmpInst::ICMP_NE)) { + return Pred == ICmpInst::ICMP_EQ ? 
FalseVal : TrueVal; + } + + return nullptr; +} + +Instruction *InstCombinerImpl::foldAndOrOfSelectUsingImpliedCond(Value *Op, + SelectInst &SI, + bool IsAnd) { + Value *CondVal = SI.getCondition(); + Value *A = SI.getTrueValue(); + Value *B = SI.getFalseValue(); + + assert(Op->getType()->isIntOrIntVectorTy(1) && + "Op must be either i1 or vector of i1."); + + Optional<bool> Res = isImpliedCondition(Op, CondVal, DL, IsAnd); + if (!Res) + return nullptr; + + Value *Zero = Constant::getNullValue(A->getType()); + Value *One = Constant::getAllOnesValue(A->getType()); + + if (*Res == true) { + if (IsAnd) + // select op, (select cond, A, B), false => select op, A, false + // and op, (select cond, A, B) => select op, A, false + // if op = true implies condval = true. + return SelectInst::Create(Op, A, Zero); + else + // select op, true, (select cond, A, B) => select op, true, A + // or op, (select cond, A, B) => select op, true, A + // if op = false implies condval = true. + return SelectInst::Create(Op, One, A); + } else { + if (IsAnd) + // select op, (select cond, A, B), false => select op, B, false + // and op, (select cond, A, B) => select op, B, false + // if op = true implies condval = false. + return SelectInst::Create(Op, B, Zero); + else + // select op, true, (select cond, A, B) => select op, true, B + // or op, (select cond, A, B) => select op, true, B + // if op = false implies condval = false. + return SelectInst::Create(Op, One, B); + } +} + +// Canonicalize select with fcmp to fabs(). -0.0 makes this tricky. We need +// fast-math-flags (nsz) or fsub with +0.0 (not fneg) for this to work. 
+static Instruction *foldSelectWithFCmpToFabs(SelectInst &SI, + InstCombinerImpl &IC) { + Value *CondVal = SI.getCondition(); + + for (bool Swap : {false, true}) { + Value *TrueVal = SI.getTrueValue(); + Value *X = SI.getFalseValue(); + CmpInst::Predicate Pred; + + if (Swap) + std::swap(TrueVal, X); + + if (!match(CondVal, m_FCmp(Pred, m_Specific(X), m_AnyZeroFP()))) + continue; + + // fold (X <= +/-0.0) ? (0.0 - X) : X to fabs(X), when 'Swap' is false + // fold (X > +/-0.0) ? X : (0.0 - X) to fabs(X), when 'Swap' is true + if (match(TrueVal, m_FSub(m_PosZeroFP(), m_Specific(X)))) { + if (!Swap && (Pred == FCmpInst::FCMP_OLE || Pred == FCmpInst::FCMP_ULE)) { + Value *Fabs = IC.Builder.CreateUnaryIntrinsic(Intrinsic::fabs, X, &SI); + return IC.replaceInstUsesWith(SI, Fabs); + } + if (Swap && (Pred == FCmpInst::FCMP_OGT || Pred == FCmpInst::FCMP_UGT)) { + Value *Fabs = IC.Builder.CreateUnaryIntrinsic(Intrinsic::fabs, X, &SI); + return IC.replaceInstUsesWith(SI, Fabs); + } + } + + // With nsz, when 'Swap' is false: + // fold (X < +/-0.0) ? -X : X or (X <= +/-0.0) ? -X : X to fabs(X) + // fold (X > +/-0.0) ? -X : X or (X >= +/-0.0) ? -X : X to -fabs(x) + // when 'Swap' is true: + // fold (X > +/-0.0) ? X : -X or (X >= +/-0.0) ? X : -X to fabs(X) + // fold (X < +/-0.0) ? X : -X or (X <= +/-0.0) ? 
X : -X to -fabs(X) + if (!match(TrueVal, m_FNeg(m_Specific(X))) || !SI.hasNoSignedZeros()) + return nullptr; + + if (Swap) + Pred = FCmpInst::getSwappedPredicate(Pred); + + bool IsLTOrLE = Pred == FCmpInst::FCMP_OLT || Pred == FCmpInst::FCMP_OLE || + Pred == FCmpInst::FCMP_ULT || Pred == FCmpInst::FCMP_ULE; + bool IsGTOrGE = Pred == FCmpInst::FCMP_OGT || Pred == FCmpInst::FCMP_OGE || + Pred == FCmpInst::FCMP_UGT || Pred == FCmpInst::FCMP_UGE; + + if (IsLTOrLE) { + Value *Fabs = IC.Builder.CreateUnaryIntrinsic(Intrinsic::fabs, X, &SI); + return IC.replaceInstUsesWith(SI, Fabs); + } + if (IsGTOrGE) { + Value *Fabs = IC.Builder.CreateUnaryIntrinsic(Intrinsic::fabs, X, &SI); + Instruction *NewFNeg = UnaryOperator::CreateFNeg(Fabs); + NewFNeg->setFastMathFlags(SI.getFastMathFlags()); + return NewFNeg; + } + } + + return nullptr; +} + +// Match the following IR pattern: +// %x.lowbits = and i8 %x, %lowbitmask +// %x.lowbits.are.zero = icmp eq i8 %x.lowbits, 0 +// %x.biased = add i8 %x, %bias +// %x.biased.highbits = and i8 %x.biased, %highbitmask +// %x.roundedup = select i1 %x.lowbits.are.zero, i8 %x, i8 %x.biased.highbits +// Define: +// %alignment = add i8 %lowbitmask, 1 +// Iff 1. an %alignment is a power-of-two (aka, %lowbitmask is a low bit mask) +// and 2. %bias is equal to either %lowbitmask or %alignment, +// and 3. 
%highbitmask is equal to ~%lowbitmask (aka, to -%alignment) +// then this pattern can be transformed into: +// %x.offset = add i8 %x, %lowbitmask +// %x.roundedup = and i8 %x.offset, %highbitmask +static Value * +foldRoundUpIntegerWithPow2Alignment(SelectInst &SI, + InstCombiner::BuilderTy &Builder) { + Value *Cond = SI.getCondition(); + Value *X = SI.getTrueValue(); + Value *XBiasedHighBits = SI.getFalseValue(); + + ICmpInst::Predicate Pred; + Value *XLowBits; + if (!match(Cond, m_ICmp(Pred, m_Value(XLowBits), m_ZeroInt())) || + !ICmpInst::isEquality(Pred)) + return nullptr; + + if (Pred == ICmpInst::Predicate::ICMP_NE) + std::swap(X, XBiasedHighBits); + + // FIXME: we could support non-splats here. + + const APInt *LowBitMaskCst; + if (!match(XLowBits, m_And(m_Specific(X), m_APIntAllowUndef(LowBitMaskCst)))) + return nullptr; + + const APInt *BiasCst, *HighBitMaskCst; + if (!match(XBiasedHighBits, + m_And(m_Add(m_Specific(X), m_APIntAllowUndef(BiasCst)), + m_APIntAllowUndef(HighBitMaskCst)))) + return nullptr; + + if (!LowBitMaskCst->isMask()) + return nullptr; + + APInt InvertedLowBitMaskCst = ~*LowBitMaskCst; + if (InvertedLowBitMaskCst != *HighBitMaskCst) + return nullptr; + + APInt AlignmentCst = *LowBitMaskCst + 1; + + if (*BiasCst != AlignmentCst && *BiasCst != *LowBitMaskCst) + return nullptr; + + if (!XBiasedHighBits->hasOneUse()) { + if (*BiasCst == *LowBitMaskCst) + return XBiasedHighBits; + return nullptr; + } + + // FIXME: could we preserve undef's here? 
+ Type *Ty = X->getType(); + Value *XOffset = Builder.CreateAdd(X, ConstantInt::get(Ty, *LowBitMaskCst), + X->getName() + ".biased"); + Value *R = Builder.CreateAnd(XOffset, ConstantInt::get(Ty, *HighBitMaskCst)); + R->takeName(&SI); + return R; +} + +Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) { + Value *CondVal = SI.getCondition(); + Value *TrueVal = SI.getTrueValue(); + Value *FalseVal = SI.getFalseValue(); + Type *SelType = SI.getType(); + + if (Value *V = simplifySelectInst(CondVal, TrueVal, FalseVal, + SQ.getWithInstruction(&SI))) + return replaceInstUsesWith(SI, V); + + if (Instruction *I = canonicalizeSelectToShuffle(SI)) + return I; + + if (Instruction *I = canonicalizeScalarSelectOfVecs(SI, *this)) + return I; + + // Avoid potential infinite loops by checking for non-constant condition. + // TODO: Can we assert instead by improving canonicalizeSelectToShuffle()? + // Scalar select must have simplified? + if (SelType->isIntOrIntVectorTy(1) && !isa<Constant>(CondVal) && + TrueVal->getType() == CondVal->getType()) { + // Folding select to and/or i1 isn't poison safe in general. impliesPoison + // checks whether folding it does not convert a well-defined value into + // poison. 
+ if (match(TrueVal, m_One())) { + if (impliesPoison(FalseVal, CondVal)) { + // Change: A = select B, true, C --> A = or B, C + return BinaryOperator::CreateOr(CondVal, FalseVal); + } + + if (auto *LHS = dyn_cast<FCmpInst>(CondVal)) + if (auto *RHS = dyn_cast<FCmpInst>(FalseVal)) + if (Value *V = foldLogicOfFCmps(LHS, RHS, /*IsAnd*/ false, + /*IsSelectLogical*/ true)) + return replaceInstUsesWith(SI, V); + } + if (match(FalseVal, m_Zero())) { + if (impliesPoison(TrueVal, CondVal)) { + // Change: A = select B, C, false --> A = and B, C + return BinaryOperator::CreateAnd(CondVal, TrueVal); + } + + if (auto *LHS = dyn_cast<FCmpInst>(CondVal)) + if (auto *RHS = dyn_cast<FCmpInst>(TrueVal)) + if (Value *V = foldLogicOfFCmps(LHS, RHS, /*IsAnd*/ true, + /*IsSelectLogical*/ true)) + return replaceInstUsesWith(SI, V); + } + + auto *One = ConstantInt::getTrue(SelType); + auto *Zero = ConstantInt::getFalse(SelType); + + // We match the "full" 0 or 1 constant here to avoid a potential infinite + // loop with vectors that may have undefined/poison elements. + // select a, false, b -> select !a, b, false + if (match(TrueVal, m_Specific(Zero))) { + Value *NotCond = Builder.CreateNot(CondVal, "not." + CondVal->getName()); + return SelectInst::Create(NotCond, FalseVal, Zero); + } + // select a, b, true -> select !a, true, b + if (match(FalseVal, m_Specific(One))) { + Value *NotCond = Builder.CreateNot(CondVal, "not." 
+ CondVal->getName()); + return SelectInst::Create(NotCond, One, TrueVal); + } + + // select a, a, b -> select a, true, b + if (CondVal == TrueVal) + return replaceOperand(SI, 1, One); + // select a, b, a -> select a, b, false + if (CondVal == FalseVal) + return replaceOperand(SI, 2, Zero); + + // select a, !a, b -> select !a, b, false + if (match(TrueVal, m_Not(m_Specific(CondVal)))) + return SelectInst::Create(TrueVal, FalseVal, Zero); + // select a, b, !a -> select !a, true, b + if (match(FalseVal, m_Not(m_Specific(CondVal)))) + return SelectInst::Create(FalseVal, One, TrueVal); + + Value *A, *B; + + // DeMorgan in select form: !a && !b --> !(a || b) + // select !a, !b, false --> not (select a, true, b) + if (match(&SI, m_LogicalAnd(m_Not(m_Value(A)), m_Not(m_Value(B)))) && + (CondVal->hasOneUse() || TrueVal->hasOneUse()) && + !match(A, m_ConstantExpr()) && !match(B, m_ConstantExpr())) + return BinaryOperator::CreateNot(Builder.CreateSelect(A, One, B)); + + // DeMorgan in select form: !a || !b --> !(a && b) + // select !a, true, !b --> not (select a, b, false) + if (match(&SI, m_LogicalOr(m_Not(m_Value(A)), m_Not(m_Value(B)))) && + (CondVal->hasOneUse() || FalseVal->hasOneUse()) && + !match(A, m_ConstantExpr()) && !match(B, m_ConstantExpr())) + return BinaryOperator::CreateNot(Builder.CreateSelect(A, B, Zero)); + + // select (select a, true, b), true, b -> select a, true, b + if (match(CondVal, m_Select(m_Value(A), m_One(), m_Value(B))) && + match(TrueVal, m_One()) && match(FalseVal, m_Specific(B))) + return replaceOperand(SI, 0, A); + // select (select a, b, false), b, false -> select a, b, false + if (match(CondVal, m_Select(m_Value(A), m_Value(B), m_Zero())) && + match(TrueVal, m_Specific(B)) && match(FalseVal, m_Zero())) + return replaceOperand(SI, 0, A); + + Value *C; + // select (~a | c), a, b -> and a, (or c, freeze(b)) + if (match(CondVal, m_c_Or(m_Not(m_Specific(TrueVal)), m_Value(C))) && + CondVal->hasOneUse()) { + FalseVal = 
Builder.CreateFreeze(FalseVal); + return BinaryOperator::CreateAnd(TrueVal, Builder.CreateOr(C, FalseVal)); + } + // select (~c & b), a, b -> and b, (or freeze(a), c) + if (match(CondVal, m_c_And(m_Not(m_Value(C)), m_Specific(FalseVal))) && + CondVal->hasOneUse()) { + TrueVal = Builder.CreateFreeze(TrueVal); + return BinaryOperator::CreateAnd(FalseVal, Builder.CreateOr(C, TrueVal)); + } + + if (!SelType->isVectorTy()) { + if (Value *S = simplifyWithOpReplaced(TrueVal, CondVal, One, SQ, + /* AllowRefinement */ true)) + return replaceOperand(SI, 1, S); + if (Value *S = simplifyWithOpReplaced(FalseVal, CondVal, Zero, SQ, + /* AllowRefinement */ true)) + return replaceOperand(SI, 2, S); + } + + if (match(FalseVal, m_Zero()) || match(TrueVal, m_One())) { + Use *Y = nullptr; + bool IsAnd = match(FalseVal, m_Zero()) ? true : false; + Value *Op1 = IsAnd ? TrueVal : FalseVal; + if (isCheckForZeroAndMulWithOverflow(CondVal, Op1, IsAnd, Y)) { + auto *FI = new FreezeInst(*Y, (*Y)->getName() + ".fr"); + InsertNewInstBefore(FI, *cast<Instruction>(Y->getUser())); + replaceUse(*Y, FI); + return replaceInstUsesWith(SI, Op1); + } + + if (auto *Op1SI = dyn_cast<SelectInst>(Op1)) + if (auto *I = foldAndOrOfSelectUsingImpliedCond(CondVal, *Op1SI, + /* IsAnd */ IsAnd)) + return I; + + if (auto *ICmp0 = dyn_cast<ICmpInst>(CondVal)) + if (auto *ICmp1 = dyn_cast<ICmpInst>(Op1)) + if (auto *V = foldAndOrOfICmps(ICmp0, ICmp1, SI, IsAnd, + /* IsLogical */ true)) + return replaceInstUsesWith(SI, V); + } + + // select (select a, true, b), c, false -> select a, c, false + // select c, (select a, true, b), false -> select c, a, false + // if c implies that b is false. 
+ if (match(CondVal, m_Select(m_Value(A), m_One(), m_Value(B))) && + match(FalseVal, m_Zero())) { + Optional<bool> Res = isImpliedCondition(TrueVal, B, DL); + if (Res && *Res == false) + return replaceOperand(SI, 0, A); + } + if (match(TrueVal, m_Select(m_Value(A), m_One(), m_Value(B))) && + match(FalseVal, m_Zero())) { + Optional<bool> Res = isImpliedCondition(CondVal, B, DL); + if (Res && *Res == false) + return replaceOperand(SI, 1, A); + } + // select c, true, (select a, b, false) -> select c, true, a + // select (select a, b, false), true, c -> select a, true, c + // if c = false implies that b = true + if (match(TrueVal, m_One()) && + match(FalseVal, m_Select(m_Value(A), m_Value(B), m_Zero()))) { + Optional<bool> Res = isImpliedCondition(CondVal, B, DL, false); + if (Res && *Res == true) + return replaceOperand(SI, 2, A); + } + if (match(CondVal, m_Select(m_Value(A), m_Value(B), m_Zero())) && + match(TrueVal, m_One())) { + Optional<bool> Res = isImpliedCondition(FalseVal, B, DL, false); + if (Res && *Res == true) + return replaceOperand(SI, 0, A); + } + + // sel (sel c, a, false), true, (sel !c, b, false) -> sel c, a, b + // sel (sel !c, a, false), true, (sel c, b, false) -> sel c, b, a + Value *C1, *C2; + if (match(CondVal, m_Select(m_Value(C1), m_Value(A), m_Zero())) && + match(TrueVal, m_One()) && + match(FalseVal, m_Select(m_Value(C2), m_Value(B), m_Zero()))) { + if (match(C2, m_Not(m_Specific(C1)))) // first case + return SelectInst::Create(C1, A, B); + else if (match(C1, m_Not(m_Specific(C2)))) // second case + return SelectInst::Create(C2, B, A); + } + } + + // Selecting between two integer or vector splat integer constants? + // + // Note that we don't handle a scalar select of vectors: + // select i1 %c, <2 x i8> <1, 1>, <2 x i8> <0, 0> + // because that may need 3 instructions to splat the condition value: + // extend, insertelement, shufflevector. + // + // Do not handle i1 TrueVal and FalseVal otherwise would result in + // zext/sext i1 to i1. 
+ if (SelType->isIntOrIntVectorTy() && !SelType->isIntOrIntVectorTy(1) && + CondVal->getType()->isVectorTy() == SelType->isVectorTy()) { + // select C, 1, 0 -> zext C to int + if (match(TrueVal, m_One()) && match(FalseVal, m_Zero())) + return new ZExtInst(CondVal, SelType); + + // select C, -1, 0 -> sext C to int + if (match(TrueVal, m_AllOnes()) && match(FalseVal, m_Zero())) + return new SExtInst(CondVal, SelType); + + // select C, 0, 1 -> zext !C to int + if (match(TrueVal, m_Zero()) && match(FalseVal, m_One())) { + Value *NotCond = Builder.CreateNot(CondVal, "not." + CondVal->getName()); + return new ZExtInst(NotCond, SelType); + } + + // select C, 0, -1 -> sext !C to int + if (match(TrueVal, m_Zero()) && match(FalseVal, m_AllOnes())) { + Value *NotCond = Builder.CreateNot(CondVal, "not." + CondVal->getName()); + return new SExtInst(NotCond, SelType); + } + } + + if (auto *FCmp = dyn_cast<FCmpInst>(CondVal)) { + Value *Cmp0 = FCmp->getOperand(0), *Cmp1 = FCmp->getOperand(1); + // Are we selecting a value based on a comparison of the two values? + if ((Cmp0 == TrueVal && Cmp1 == FalseVal) || + (Cmp0 == FalseVal && Cmp1 == TrueVal)) { + // Canonicalize to use ordered comparisons by swapping the select + // operands. + // + // e.g. + // (X ugt Y) ? X : Y -> (X ole Y) ? Y : X + if (FCmp->hasOneUse() && FCmpInst::isUnordered(FCmp->getPredicate())) { + FCmpInst::Predicate InvPred = FCmp->getInversePredicate(); + IRBuilder<>::FastMathFlagGuard FMFG(Builder); + // FIXME: The FMF should propagate from the select, not the fcmp. + Builder.setFastMathFlags(FCmp->getFastMathFlags()); + Value *NewCond = Builder.CreateFCmp(InvPred, Cmp0, Cmp1, + FCmp->getName() + ".inv"); + Value *NewSel = Builder.CreateSelect(NewCond, FalseVal, TrueVal); + return replaceInstUsesWith(SI, NewSel); + } + + // NOTE: if we wanted to, this is where to detect MIN/MAX + } + } + + // Fold selecting to fabs. 
+ if (Instruction *Fabs = foldSelectWithFCmpToFabs(SI, *this)) + return Fabs; + + // See if we are selecting two values based on a comparison of the two values. + if (ICmpInst *ICI = dyn_cast<ICmpInst>(CondVal)) + if (Instruction *Result = foldSelectInstWithICmp(SI, ICI)) + return Result; + + if (Instruction *Add = foldAddSubSelect(SI, Builder)) + return Add; + if (Instruction *Add = foldOverflowingAddSubSelect(SI, Builder)) + return Add; + if (Instruction *Or = foldSetClearBits(SI, Builder)) + return Or; + if (Instruction *Mul = foldSelectZeroOrMul(SI, *this)) + return Mul; + + // Turn (select C, (op X, Y), (op X, Z)) -> (op X, (select C, Y, Z)) + auto *TI = dyn_cast<Instruction>(TrueVal); + auto *FI = dyn_cast<Instruction>(FalseVal); + if (TI && FI && TI->getOpcode() == FI->getOpcode()) + if (Instruction *IV = foldSelectOpOp(SI, TI, FI)) + return IV; + + if (Instruction *I = foldSelectExtConst(SI)) + return I; + + // Fold (select C, (gep Ptr, Idx), Ptr) -> (gep Ptr, (select C, Idx, 0)) + // Fold (select C, Ptr, (gep Ptr, Idx)) -> (gep Ptr, (select C, 0, Idx)) + auto SelectGepWithBase = [&](GetElementPtrInst *Gep, Value *Base, + bool Swap) -> GetElementPtrInst * { + Value *Ptr = Gep->getPointerOperand(); + if (Gep->getNumOperands() != 2 || Gep->getPointerOperand() != Base || + !Gep->hasOneUse()) + return nullptr; + Value *Idx = Gep->getOperand(1); + if (isa<VectorType>(CondVal->getType()) && !isa<VectorType>(Idx->getType())) + return nullptr; + Type *ElementType = Gep->getResultElementType(); + Value *NewT = Idx; + Value *NewF = Constant::getNullValue(Idx->getType()); + if (Swap) + std::swap(NewT, NewF); + Value *NewSI = + Builder.CreateSelect(CondVal, NewT, NewF, SI.getName() + ".idx", &SI); + return GetElementPtrInst::Create(ElementType, Ptr, {NewSI}); + }; + if (auto *TrueGep = dyn_cast<GetElementPtrInst>(TrueVal)) + if (auto *NewGep = SelectGepWithBase(TrueGep, FalseVal, false)) + return NewGep; + if (auto *FalseGep = dyn_cast<GetElementPtrInst>(FalseVal)) + 
if (auto *NewGep = SelectGepWithBase(FalseGep, TrueVal, true)) + return NewGep; + + // See if we can fold the select into one of our operands. + if (SelType->isIntOrIntVectorTy() || SelType->isFPOrFPVectorTy()) { + if (Instruction *FoldI = foldSelectIntoOp(SI, TrueVal, FalseVal)) + return FoldI; + + Value *LHS, *RHS; + Instruction::CastOps CastOp; + SelectPatternResult SPR = matchSelectPattern(&SI, LHS, RHS, &CastOp); + auto SPF = SPR.Flavor; + if (SPF) { + Value *LHS2, *RHS2; + if (SelectPatternFlavor SPF2 = matchSelectPattern(LHS, LHS2, RHS2).Flavor) + if (Instruction *R = foldSPFofSPF(cast<Instruction>(LHS), SPF2, LHS2, + RHS2, SI, SPF, RHS)) + return R; + if (SelectPatternFlavor SPF2 = matchSelectPattern(RHS, LHS2, RHS2).Flavor) + if (Instruction *R = foldSPFofSPF(cast<Instruction>(RHS), SPF2, LHS2, + RHS2, SI, SPF, LHS)) + return R; + } + + if (SelectPatternResult::isMinOrMax(SPF)) { + // Canonicalize so that + // - type casts are outside select patterns. + // - float clamp is transformed to min/max pattern + + bool IsCastNeeded = LHS->getType() != SelType; + Value *CmpLHS = cast<CmpInst>(CondVal)->getOperand(0); + Value *CmpRHS = cast<CmpInst>(CondVal)->getOperand(1); + if (IsCastNeeded || + (LHS->getType()->isFPOrFPVectorTy() && + ((CmpLHS != LHS && CmpLHS != RHS) || + (CmpRHS != LHS && CmpRHS != RHS)))) { + CmpInst::Predicate MinMaxPred = getMinMaxPred(SPF, SPR.Ordered); + + Value *Cmp; + if (CmpInst::isIntPredicate(MinMaxPred)) { + Cmp = Builder.CreateICmp(MinMaxPred, LHS, RHS); + } else { + IRBuilder<>::FastMathFlagGuard FMFG(Builder); + auto FMF = + cast<FPMathOperator>(SI.getCondition())->getFastMathFlags(); + Builder.setFastMathFlags(FMF); + Cmp = Builder.CreateFCmp(MinMaxPred, LHS, RHS); + } + + Value *NewSI = Builder.CreateSelect(Cmp, LHS, RHS, SI.getName(), &SI); + if (!IsCastNeeded) + return replaceInstUsesWith(SI, NewSI); + + Value *NewCast = Builder.CreateCast(CastOp, NewSI, SelType); + return replaceInstUsesWith(SI, NewCast); + } + } + } + + // 
Canonicalize select of FP values where NaN and -0.0 are not valid as + // minnum/maxnum intrinsics. + if (isa<FPMathOperator>(SI) && SI.hasNoNaNs() && SI.hasNoSignedZeros()) { + Value *X, *Y; + if (match(&SI, m_OrdFMax(m_Value(X), m_Value(Y)))) + return replaceInstUsesWith( + SI, Builder.CreateBinaryIntrinsic(Intrinsic::maxnum, X, Y, &SI)); + + if (match(&SI, m_OrdFMin(m_Value(X), m_Value(Y)))) + return replaceInstUsesWith( + SI, Builder.CreateBinaryIntrinsic(Intrinsic::minnum, X, Y, &SI)); + } + + // See if we can fold the select into a phi node if the condition is a phi. + if (auto *PN = dyn_cast<PHINode>(SI.getCondition())) + // The true/false values have to be live in the PHI predecessor's blocks. + if (canSelectOperandBeMappingIntoPredBlock(TrueVal, SI) && + canSelectOperandBeMappingIntoPredBlock(FalseVal, SI)) + if (Instruction *NV = foldOpIntoPhi(SI, PN)) + return NV; + + if (SelectInst *TrueSI = dyn_cast<SelectInst>(TrueVal)) { + if (TrueSI->getCondition()->getType() == CondVal->getType()) { + // select(C, select(C, a, b), c) -> select(C, a, c) + if (TrueSI->getCondition() == CondVal) { + if (SI.getTrueValue() == TrueSI->getTrueValue()) + return nullptr; + return replaceOperand(SI, 1, TrueSI->getTrueValue()); + } + // select(C0, select(C1, a, b), b) -> select(C0&C1, a, b) + // We choose this as normal form to enable folding on the And and + // shortening paths for the values (this helps getUnderlyingObjects() for + // example). 
+ if (TrueSI->getFalseValue() == FalseVal && TrueSI->hasOneUse()) { + Value *And = Builder.CreateLogicalAnd(CondVal, TrueSI->getCondition()); + replaceOperand(SI, 0, And); + replaceOperand(SI, 1, TrueSI->getTrueValue()); + return &SI; + } + } + } + if (SelectInst *FalseSI = dyn_cast<SelectInst>(FalseVal)) { + if (FalseSI->getCondition()->getType() == CondVal->getType()) { + // select(C, a, select(C, b, c)) -> select(C, a, c) + if (FalseSI->getCondition() == CondVal) { + if (SI.getFalseValue() == FalseSI->getFalseValue()) + return nullptr; + return replaceOperand(SI, 2, FalseSI->getFalseValue()); + } + // select(C0, a, select(C1, a, b)) -> select(C0|C1, a, b) + if (FalseSI->getTrueValue() == TrueVal && FalseSI->hasOneUse()) { + Value *Or = Builder.CreateLogicalOr(CondVal, FalseSI->getCondition()); + replaceOperand(SI, 0, Or); + replaceOperand(SI, 2, FalseSI->getFalseValue()); + return &SI; + } + } + } + + auto canMergeSelectThroughBinop = [](BinaryOperator *BO) { + // The select might be preventing a division by 0. + switch (BO->getOpcode()) { + default: + return true; + case Instruction::SRem: + case Instruction::URem: + case Instruction::SDiv: + case Instruction::UDiv: + return false; + } + }; + + // Try to simplify a binop sandwiched between 2 selects with the same + // condition. 
+ // select(C, binop(select(C, X, Y), W), Z) -> select(C, binop(X, W), Z) + BinaryOperator *TrueBO; + if (match(TrueVal, m_OneUse(m_BinOp(TrueBO))) && + canMergeSelectThroughBinop(TrueBO)) { + if (auto *TrueBOSI = dyn_cast<SelectInst>(TrueBO->getOperand(0))) { + if (TrueBOSI->getCondition() == CondVal) { + replaceOperand(*TrueBO, 0, TrueBOSI->getTrueValue()); + Worklist.push(TrueBO); + return &SI; + } + } + if (auto *TrueBOSI = dyn_cast<SelectInst>(TrueBO->getOperand(1))) { + if (TrueBOSI->getCondition() == CondVal) { + replaceOperand(*TrueBO, 1, TrueBOSI->getTrueValue()); + Worklist.push(TrueBO); + return &SI; + } + } + } + + // select(C, Z, binop(select(C, X, Y), W)) -> select(C, Z, binop(Y, W)) + BinaryOperator *FalseBO; + if (match(FalseVal, m_OneUse(m_BinOp(FalseBO))) && + canMergeSelectThroughBinop(FalseBO)) { + if (auto *FalseBOSI = dyn_cast<SelectInst>(FalseBO->getOperand(0))) { + if (FalseBOSI->getCondition() == CondVal) { + replaceOperand(*FalseBO, 0, FalseBOSI->getFalseValue()); + Worklist.push(FalseBO); + return &SI; + } + } + if (auto *FalseBOSI = dyn_cast<SelectInst>(FalseBO->getOperand(1))) { + if (FalseBOSI->getCondition() == CondVal) { + replaceOperand(*FalseBO, 1, FalseBOSI->getFalseValue()); + Worklist.push(FalseBO); + return &SI; + } + } + } + + Value *NotCond; + if (match(CondVal, m_Not(m_Value(NotCond))) && + !InstCombiner::shouldAvoidAbsorbingNotIntoSelect(SI)) { + replaceOperand(SI, 0, NotCond); + SI.swapValues(); + SI.swapProfMetadata(); + return &SI; + } + + if (Instruction *I = foldVectorSelect(SI)) + return I; + + // If we can compute the condition, there's no need for a select. + // Like the above fold, we are attempting to reduce compile-time cost by + // putting this fold here with limitations rather than in InstSimplify. + // The motivation for this call into value tracking is to take advantage of + // the assumption cache, so make sure that is populated. 
+ if (!CondVal->getType()->isVectorTy() && !AC.assumptions().empty()) { + KnownBits Known(1); + computeKnownBits(CondVal, Known, 0, &SI); + if (Known.One.isOne()) + return replaceInstUsesWith(SI, TrueVal); + if (Known.Zero.isOne()) + return replaceInstUsesWith(SI, FalseVal); + } + + if (Instruction *BitCastSel = foldSelectCmpBitcasts(SI, Builder)) + return BitCastSel; + + // Simplify selects that test the returned flag of cmpxchg instructions. + if (Value *V = foldSelectCmpXchg(SI)) + return replaceInstUsesWith(SI, V); + + if (Instruction *Select = foldSelectBinOpIdentity(SI, TLI, *this)) + return Select; + + if (Instruction *Funnel = foldSelectFunnelShift(SI, Builder)) + return Funnel; + + if (Instruction *Copysign = foldSelectToCopysign(SI, Builder)) + return Copysign; + + if (Instruction *PN = foldSelectToPhi(SI, DT, Builder)) + return replaceInstUsesWith(SI, PN); + + if (Value *Fr = foldSelectWithFrozenICmp(SI, Builder)) + return replaceInstUsesWith(SI, Fr); + + if (Value *V = foldRoundUpIntegerWithPow2Alignment(SI, Builder)) + return replaceInstUsesWith(SI, V); + + // select(mask, mload(,,mask,0), 0) -> mload(,,mask,0) + // Load inst is intentionally not checked for hasOneUse() + if (match(FalseVal, m_Zero()) && + (match(TrueVal, m_MaskedLoad(m_Value(), m_Value(), m_Specific(CondVal), + m_CombineOr(m_Undef(), m_Zero()))) || + match(TrueVal, m_MaskedGather(m_Value(), m_Value(), m_Specific(CondVal), + m_CombineOr(m_Undef(), m_Zero()))))) { + auto *MaskedInst = cast<IntrinsicInst>(TrueVal); + if (isa<UndefValue>(MaskedInst->getArgOperand(3))) + MaskedInst->setArgOperand(3, FalseVal /* Zero */); + return replaceInstUsesWith(SI, MaskedInst); + } + + Value *Mask; + if (match(TrueVal, m_Zero()) && + (match(FalseVal, m_MaskedLoad(m_Value(), m_Value(), m_Value(Mask), + m_CombineOr(m_Undef(), m_Zero()))) || + match(FalseVal, m_MaskedGather(m_Value(), m_Value(), m_Value(Mask), + m_CombineOr(m_Undef(), m_Zero())))) && + (CondVal->getType() == Mask->getType())) { + // We 
can remove the select by ensuring the load zeros all lanes the + // select would have. We determine this by proving there is no overlap + // between the load and select masks. + // (i.e (load_mask & select_mask) == 0 == no overlap) + bool CanMergeSelectIntoLoad = false; + if (Value *V = simplifyAndInst(CondVal, Mask, SQ.getWithInstruction(&SI))) + CanMergeSelectIntoLoad = match(V, m_Zero()); + + if (CanMergeSelectIntoLoad) { + auto *MaskedInst = cast<IntrinsicInst>(FalseVal); + if (isa<UndefValue>(MaskedInst->getArgOperand(3))) + MaskedInst->setArgOperand(3, TrueVal /* Zero */); + return replaceInstUsesWith(SI, MaskedInst); + } + } + + return nullptr; +} diff --git a/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp b/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp new file mode 100644 index 000000000000..f4e2d1239f0f --- /dev/null +++ b/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp @@ -0,0 +1,1466 @@ +//===- InstCombineShifts.cpp ----------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the visitShl, visitLShr, and visitAShr functions. +// +//===----------------------------------------------------------------------===// + +#include "InstCombineInternal.h" +#include "llvm/Analysis/InstructionSimplify.h" +#include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/PatternMatch.h" +#include "llvm/Transforms/InstCombine/InstCombiner.h" +using namespace llvm; +using namespace PatternMatch; + +#define DEBUG_TYPE "instcombine" + +bool canTryToConstantAddTwoShiftAmounts(Value *Sh0, Value *ShAmt0, Value *Sh1, + Value *ShAmt1) { + // We have two shift amounts from two different shifts. 
The types of those + // shift amounts may not match. If that's the case let's bail out now. + if (ShAmt0->getType() != ShAmt1->getType()) + return false; + + // As input, we have the following pattern: + // Sh0 (Sh1 X, Q), K + // We want to rewrite that as: + // Sh x, (Q+K) iff (Q+K) u< bitwidth(x) + // While we know that originally (Q+K) would not overflow + // (because 2 * (N-1) u<= iN -1), we have looked past extensions of + // shift amounts, so it may now overflow in smaller bitwidth. + // To ensure that does not happen, we need to ensure that the total maximal + // shift amount is still representable in that smaller bit width. + unsigned MaximalPossibleTotalShiftAmount = + (Sh0->getType()->getScalarSizeInBits() - 1) + + (Sh1->getType()->getScalarSizeInBits() - 1); + APInt MaximalRepresentableShiftAmount = + APInt::getAllOnes(ShAmt0->getType()->getScalarSizeInBits()); + return MaximalRepresentableShiftAmount.uge(MaximalPossibleTotalShiftAmount); +} + +// Given pattern: +// (x shiftopcode Q) shiftopcode K +// we should rewrite it as +// x shiftopcode (Q+K) iff (Q+K) u< bitwidth(x) and +// +// This is valid for any shift, but they must be identical, and we must be +// careful in case we have (zext(Q)+zext(K)) and look past extensions, +// (Q+K) must not overflow or else (Q+K) u< bitwidth(x) is bogus. +// +// AnalyzeForSignBitExtraction indicates that we will only analyze whether this +// pattern has any 2 right-shifts that sum to 1 less than original bit width. +Value *InstCombinerImpl::reassociateShiftAmtsOfTwoSameDirectionShifts( + BinaryOperator *Sh0, const SimplifyQuery &SQ, + bool AnalyzeForSignBitExtraction) { + // Look for a shift of some instruction, ignore zext of shift amount if any. + Instruction *Sh0Op0; + Value *ShAmt0; + if (!match(Sh0, + m_Shift(m_Instruction(Sh0Op0), m_ZExtOrSelf(m_Value(ShAmt0))))) + return nullptr; + + // If there is a truncation between the two shifts, we must make note of it + // and look through it. 
The truncation imposes additional constraints on the + // transform. + Instruction *Sh1; + Value *Trunc = nullptr; + match(Sh0Op0, + m_CombineOr(m_CombineAnd(m_Trunc(m_Instruction(Sh1)), m_Value(Trunc)), + m_Instruction(Sh1))); + + // Inner shift: (x shiftopcode ShAmt1) + // Like with other shift, ignore zext of shift amount if any. + Value *X, *ShAmt1; + if (!match(Sh1, m_Shift(m_Value(X), m_ZExtOrSelf(m_Value(ShAmt1))))) + return nullptr; + + // Verify that it would be safe to try to add those two shift amounts. + if (!canTryToConstantAddTwoShiftAmounts(Sh0, ShAmt0, Sh1, ShAmt1)) + return nullptr; + + // We are only looking for signbit extraction if we have two right shifts. + bool HadTwoRightShifts = match(Sh0, m_Shr(m_Value(), m_Value())) && + match(Sh1, m_Shr(m_Value(), m_Value())); + // ... and if it's not two right-shifts, we know the answer already. + if (AnalyzeForSignBitExtraction && !HadTwoRightShifts) + return nullptr; + + // The shift opcodes must be identical, unless we are just checking whether + // this pattern can be interpreted as a sign-bit-extraction. + Instruction::BinaryOps ShiftOpcode = Sh0->getOpcode(); + bool IdenticalShOpcodes = Sh0->getOpcode() == Sh1->getOpcode(); + if (!IdenticalShOpcodes && !AnalyzeForSignBitExtraction) + return nullptr; + + // If we saw truncation, we'll need to produce extra instruction, + // and for that one of the operands of the shift must be one-use, + // unless of course we don't actually plan to produce any instructions here. + if (Trunc && !AnalyzeForSignBitExtraction && + !match(Sh0, m_c_BinOp(m_OneUse(m_Value()), m_Value()))) + return nullptr; + + // Can we fold (ShAmt0+ShAmt1) ? + auto *NewShAmt = dyn_cast_or_null<Constant>( + simplifyAddInst(ShAmt0, ShAmt1, /*isNSW=*/false, /*isNUW=*/false, + SQ.getWithInstruction(Sh0))); + if (!NewShAmt) + return nullptr; // Did not simplify. 
+ unsigned NewShAmtBitWidth = NewShAmt->getType()->getScalarSizeInBits(); + unsigned XBitWidth = X->getType()->getScalarSizeInBits(); + // Is the new shift amount smaller than the bit width of inner/new shift? + if (!match(NewShAmt, m_SpecificInt_ICMP(ICmpInst::Predicate::ICMP_ULT, + APInt(NewShAmtBitWidth, XBitWidth)))) + return nullptr; // FIXME: could perform constant-folding. + + // If there was a truncation, and we have a right-shift, we can only fold if + // we are left with the original sign bit. Likewise, if we were just checking + // that this is a sighbit extraction, this is the place to check it. + // FIXME: zero shift amount is also legal here, but we can't *easily* check + // more than one predicate so it's not really worth it. + if (HadTwoRightShifts && (Trunc || AnalyzeForSignBitExtraction)) { + // If it's not a sign bit extraction, then we're done. + if (!match(NewShAmt, + m_SpecificInt_ICMP(ICmpInst::Predicate::ICMP_EQ, + APInt(NewShAmtBitWidth, XBitWidth - 1)))) + return nullptr; + // If it is, and that was the question, return the base value. + if (AnalyzeForSignBitExtraction) + return X; + } + + assert(IdenticalShOpcodes && "Should not get here with different shifts."); + + // All good, we can do this fold. + NewShAmt = ConstantExpr::getZExtOrBitCast(NewShAmt, X->getType()); + + BinaryOperator *NewShift = BinaryOperator::Create(ShiftOpcode, X, NewShAmt); + + // The flags can only be propagated if there wasn't a trunc. + if (!Trunc) { + // If the pattern did not involve trunc, and both of the original shifts + // had the same flag set, preserve the flag. 
+ if (ShiftOpcode == Instruction::BinaryOps::Shl) { + NewShift->setHasNoUnsignedWrap(Sh0->hasNoUnsignedWrap() && + Sh1->hasNoUnsignedWrap()); + NewShift->setHasNoSignedWrap(Sh0->hasNoSignedWrap() && + Sh1->hasNoSignedWrap()); + } else { + NewShift->setIsExact(Sh0->isExact() && Sh1->isExact()); + } + } + + Instruction *Ret = NewShift; + if (Trunc) { + Builder.Insert(NewShift); + Ret = CastInst::Create(Instruction::Trunc, NewShift, Sh0->getType()); + } + + return Ret; +} + +// If we have some pattern that leaves only some low bits set, and then performs +// left-shift of those bits, if none of the bits that are left after the final +// shift are modified by the mask, we can omit the mask. +// +// There are many variants to this pattern: +// a) (x & ((1 << MaskShAmt) - 1)) << ShiftShAmt +// b) (x & (~(-1 << MaskShAmt))) << ShiftShAmt +// c) (x & (-1 l>> MaskShAmt)) << ShiftShAmt +// d) (x & ((-1 << MaskShAmt) l>> MaskShAmt)) << ShiftShAmt +// e) ((x << MaskShAmt) l>> MaskShAmt) << ShiftShAmt +// f) ((x << MaskShAmt) a>> MaskShAmt) << ShiftShAmt +// All these patterns can be simplified to just: +// x << ShiftShAmt +// iff: +// a,b) (MaskShAmt+ShiftShAmt) u>= bitwidth(x) +// c,d,e,f) (ShiftShAmt-MaskShAmt) s>= 0 (i.e. ShiftShAmt u>= MaskShAmt) +static Instruction * +dropRedundantMaskingOfLeftShiftInput(BinaryOperator *OuterShift, + const SimplifyQuery &Q, + InstCombiner::BuilderTy &Builder) { + assert(OuterShift->getOpcode() == Instruction::BinaryOps::Shl && + "The input must be 'shl'!"); + + Value *Masked, *ShiftShAmt; + match(OuterShift, + m_Shift(m_Value(Masked), m_ZExtOrSelf(m_Value(ShiftShAmt)))); + + // *If* there is a truncation between an outer shift and a possibly-mask, + // then said truncation *must* be one-use, else we can't perform the fold. 
+ Value *Trunc; + if (match(Masked, m_CombineAnd(m_Trunc(m_Value(Masked)), m_Value(Trunc))) && + !Trunc->hasOneUse()) + return nullptr; + + Type *NarrowestTy = OuterShift->getType(); + Type *WidestTy = Masked->getType(); + bool HadTrunc = WidestTy != NarrowestTy; + + // The mask must be computed in a type twice as wide to ensure + // that no bits are lost if the sum-of-shifts is wider than the base type. + Type *ExtendedTy = WidestTy->getExtendedType(); + + Value *MaskShAmt; + + // ((1 << MaskShAmt) - 1) + auto MaskA = m_Add(m_Shl(m_One(), m_Value(MaskShAmt)), m_AllOnes()); + // (~(-1 << maskNbits)) + auto MaskB = m_Xor(m_Shl(m_AllOnes(), m_Value(MaskShAmt)), m_AllOnes()); + // (-1 l>> MaskShAmt) + auto MaskC = m_LShr(m_AllOnes(), m_Value(MaskShAmt)); + // ((-1 << MaskShAmt) l>> MaskShAmt) + auto MaskD = + m_LShr(m_Shl(m_AllOnes(), m_Value(MaskShAmt)), m_Deferred(MaskShAmt)); + + Value *X; + Constant *NewMask; + + if (match(Masked, m_c_And(m_CombineOr(MaskA, MaskB), m_Value(X)))) { + // Peek through an optional zext of the shift amount. + match(MaskShAmt, m_ZExtOrSelf(m_Value(MaskShAmt))); + + // Verify that it would be safe to try to add those two shift amounts. + if (!canTryToConstantAddTwoShiftAmounts(OuterShift, ShiftShAmt, Masked, + MaskShAmt)) + return nullptr; + + // Can we simplify (MaskShAmt+ShiftShAmt) ? + auto *SumOfShAmts = dyn_cast_or_null<Constant>(simplifyAddInst( + MaskShAmt, ShiftShAmt, /*IsNSW=*/false, /*IsNUW=*/false, Q)); + if (!SumOfShAmts) + return nullptr; // Did not simplify. + // In this pattern SumOfShAmts correlates with the number of low bits + // that shall remain in the root value (OuterShift). + + // An extend of an undef value becomes zero because the high bits are never + // completely unknown. Replace the `undef` shift amounts with final + // shift bitwidth to ensure that the value remains undef when creating the + // subsequent shift op. 
+ SumOfShAmts = Constant::replaceUndefsWith( + SumOfShAmts, ConstantInt::get(SumOfShAmts->getType()->getScalarType(), + ExtendedTy->getScalarSizeInBits())); + auto *ExtendedSumOfShAmts = ConstantExpr::getZExt(SumOfShAmts, ExtendedTy); + // And compute the mask as usual: ~(-1 << (SumOfShAmts)) + auto *ExtendedAllOnes = ConstantExpr::getAllOnesValue(ExtendedTy); + auto *ExtendedInvertedMask = + ConstantExpr::getShl(ExtendedAllOnes, ExtendedSumOfShAmts); + NewMask = ConstantExpr::getNot(ExtendedInvertedMask); + } else if (match(Masked, m_c_And(m_CombineOr(MaskC, MaskD), m_Value(X))) || + match(Masked, m_Shr(m_Shl(m_Value(X), m_Value(MaskShAmt)), + m_Deferred(MaskShAmt)))) { + // Peek through an optional zext of the shift amount. + match(MaskShAmt, m_ZExtOrSelf(m_Value(MaskShAmt))); + + // Verify that it would be safe to try to add those two shift amounts. + if (!canTryToConstantAddTwoShiftAmounts(OuterShift, ShiftShAmt, Masked, + MaskShAmt)) + return nullptr; + + // Can we simplify (ShiftShAmt-MaskShAmt) ? + auto *ShAmtsDiff = dyn_cast_or_null<Constant>(simplifySubInst( + ShiftShAmt, MaskShAmt, /*IsNSW=*/false, /*IsNUW=*/false, Q)); + if (!ShAmtsDiff) + return nullptr; // Did not simplify. + // In this pattern ShAmtsDiff correlates with the number of high bits that + // shall be unset in the root value (OuterShift). + + // An extend of an undef value becomes zero because the high bits are never + // completely unknown. Replace the `undef` shift amounts with negated + // bitwidth of innermost shift to ensure that the value remains undef when + // creating the subsequent shift op. 
+ unsigned WidestTyBitWidth = WidestTy->getScalarSizeInBits(); + ShAmtsDiff = Constant::replaceUndefsWith( + ShAmtsDiff, ConstantInt::get(ShAmtsDiff->getType()->getScalarType(), + -WidestTyBitWidth)); + auto *ExtendedNumHighBitsToClear = ConstantExpr::getZExt( + ConstantExpr::getSub(ConstantInt::get(ShAmtsDiff->getType(), + WidestTyBitWidth, + /*isSigned=*/false), + ShAmtsDiff), + ExtendedTy); + // And compute the mask as usual: (-1 l>> (NumHighBitsToClear)) + auto *ExtendedAllOnes = ConstantExpr::getAllOnesValue(ExtendedTy); + NewMask = + ConstantExpr::getLShr(ExtendedAllOnes, ExtendedNumHighBitsToClear); + } else + return nullptr; // Don't know anything about this pattern. + + NewMask = ConstantExpr::getTrunc(NewMask, NarrowestTy); + + // Does this mask has any unset bits? If not then we can just not apply it. + bool NeedMask = !match(NewMask, m_AllOnes()); + + // If we need to apply a mask, there are several more restrictions we have. + if (NeedMask) { + // The old masking instruction must go away. + if (!Masked->hasOneUse()) + return nullptr; + // The original "masking" instruction must not have been`ashr`. + if (match(Masked, m_AShr(m_Value(), m_Value()))) + return nullptr; + } + + // If we need to apply truncation, let's do it first, since we can. + // We have already ensured that the old truncation will go away. + if (HadTrunc) + X = Builder.CreateTrunc(X, NarrowestTy); + + // No 'NUW'/'NSW'! We no longer know that we won't shift-out non-0 bits. + // We didn't change the Type of this outermost shift, so we can just do it. 
+ auto *NewShift = BinaryOperator::Create(OuterShift->getOpcode(), X, + OuterShift->getOperand(1)); + if (!NeedMask) + return NewShift; + + Builder.Insert(NewShift); + return BinaryOperator::Create(Instruction::And, NewShift, NewMask); +} + +/// If we have a shift-by-constant of a bitwise logic op that itself has a +/// shift-by-constant operand with identical opcode, we may be able to convert +/// that into 2 independent shifts followed by the logic op. This eliminates a +/// a use of an intermediate value (reduces dependency chain). +static Instruction *foldShiftOfShiftedLogic(BinaryOperator &I, + InstCombiner::BuilderTy &Builder) { + assert(I.isShift() && "Expected a shift as input"); + auto *LogicInst = dyn_cast<BinaryOperator>(I.getOperand(0)); + if (!LogicInst || !LogicInst->isBitwiseLogicOp() || !LogicInst->hasOneUse()) + return nullptr; + + Constant *C0, *C1; + if (!match(I.getOperand(1), m_Constant(C1))) + return nullptr; + + Instruction::BinaryOps ShiftOpcode = I.getOpcode(); + Type *Ty = I.getType(); + + // Find a matching one-use shift by constant. The fold is not valid if the sum + // of the shift values equals or exceeds bitwidth. + // TODO: Remove the one-use check if the other logic operand (Y) is constant. + Value *X, *Y; + auto matchFirstShift = [&](Value *V) { + APInt Threshold(Ty->getScalarSizeInBits(), Ty->getScalarSizeInBits()); + return match(V, m_BinOp(ShiftOpcode, m_Value(), m_Value())) && + match(V, m_OneUse(m_Shift(m_Value(X), m_Constant(C0)))) && + match(ConstantExpr::getAdd(C0, C1), + m_SpecificInt_ICMP(ICmpInst::ICMP_ULT, Threshold)); + }; + + // Logic ops are commutative, so check each operand for a match. 
+ if (matchFirstShift(LogicInst->getOperand(0))) + Y = LogicInst->getOperand(1); + else if (matchFirstShift(LogicInst->getOperand(1))) + Y = LogicInst->getOperand(0); + else + return nullptr; + + // shift (logic (shift X, C0), Y), C1 -> logic (shift X, C0+C1), (shift Y, C1) + Constant *ShiftSumC = ConstantExpr::getAdd(C0, C1); + Value *NewShift1 = Builder.CreateBinOp(ShiftOpcode, X, ShiftSumC); + Value *NewShift2 = Builder.CreateBinOp(ShiftOpcode, Y, I.getOperand(1)); + return BinaryOperator::Create(LogicInst->getOpcode(), NewShift1, NewShift2); +} + +Instruction *InstCombinerImpl::commonShiftTransforms(BinaryOperator &I) { + if (Instruction *Phi = foldBinopWithPhiOperands(I)) + return Phi; + + Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); + assert(Op0->getType() == Op1->getType()); + Type *Ty = I.getType(); + + // If the shift amount is a one-use `sext`, we can demote it to `zext`. + Value *Y; + if (match(Op1, m_OneUse(m_SExt(m_Value(Y))))) { + Value *NewExt = Builder.CreateZExt(Y, Ty, Op1->getName()); + return BinaryOperator::Create(I.getOpcode(), Op0, NewExt); + } + + // See if we can fold away this shift. + if (SimplifyDemandedInstructionBits(I)) + return &I; + + // Try to fold constant and into select arguments. 
+ if (isa<Constant>(Op0)) + if (SelectInst *SI = dyn_cast<SelectInst>(Op1)) + if (Instruction *R = FoldOpIntoSelect(I, SI)) + return R; + + if (Constant *CUI = dyn_cast<Constant>(Op1)) + if (Instruction *Res = FoldShiftByConstant(Op0, CUI, I)) + return Res; + + if (auto *NewShift = cast_or_null<Instruction>( + reassociateShiftAmtsOfTwoSameDirectionShifts(&I, SQ))) + return NewShift; + + // Pre-shift a constant shifted by a variable amount with constant offset: + // C shift (A add nuw C1) --> (C shift C1) shift A + Value *A; + Constant *C, *C1; + if (match(Op0, m_Constant(C)) && + match(Op1, m_NUWAdd(m_Value(A), m_Constant(C1)))) { + Value *NewC = Builder.CreateBinOp(I.getOpcode(), C, C1); + return BinaryOperator::Create(I.getOpcode(), NewC, A); + } + + unsigned BitWidth = Ty->getScalarSizeInBits(); + + const APInt *AC, *AddC; + // Try to pre-shift a constant shifted by a variable amount added with a + // negative number: + // C << (X - AddC) --> (C >> AddC) << X + // and + // C >> (X - AddC) --> (C << AddC) >> X + if (match(Op0, m_APInt(AC)) && match(Op1, m_Add(m_Value(A), m_APInt(AddC))) && + AddC->isNegative() && (-*AddC).ult(BitWidth)) { + assert(!AC->isZero() && "Expected simplify of shifted zero"); + unsigned PosOffset = (-*AddC).getZExtValue(); + + auto isSuitableForPreShift = [PosOffset, &I, AC]() { + switch (I.getOpcode()) { + default: + return false; + case Instruction::Shl: + return (I.hasNoSignedWrap() || I.hasNoUnsignedWrap()) && + AC->eq(AC->lshr(PosOffset).shl(PosOffset)); + case Instruction::LShr: + return I.isExact() && AC->eq(AC->shl(PosOffset).lshr(PosOffset)); + case Instruction::AShr: + return I.isExact() && AC->eq(AC->shl(PosOffset).ashr(PosOffset)); + } + }; + if (isSuitableForPreShift()) { + Constant *NewC = ConstantInt::get(Ty, I.getOpcode() == Instruction::Shl + ? 
AC->lshr(PosOffset) + : AC->shl(PosOffset)); + BinaryOperator *NewShiftOp = + BinaryOperator::Create(I.getOpcode(), NewC, A); + if (I.getOpcode() == Instruction::Shl) { + NewShiftOp->setHasNoUnsignedWrap(I.hasNoUnsignedWrap()); + } else { + NewShiftOp->setIsExact(); + } + return NewShiftOp; + } + } + + // X shift (A srem C) -> X shift (A and (C - 1)) iff C is a power of 2. + // Because shifts by negative values (which could occur if A were negative) + // are undefined. + if (Op1->hasOneUse() && match(Op1, m_SRem(m_Value(A), m_Constant(C))) && + match(C, m_Power2())) { + // FIXME: Should this get moved into SimplifyDemandedBits by saying we don't + // demand the sign bit (and many others) here?? + Constant *Mask = ConstantExpr::getSub(C, ConstantInt::get(Ty, 1)); + Value *Rem = Builder.CreateAnd(A, Mask, Op1->getName()); + return replaceOperand(I, 1, Rem); + } + + if (Instruction *Logic = foldShiftOfShiftedLogic(I, Builder)) + return Logic; + + return nullptr; +} + +/// Return true if we can simplify two logical (either left or right) shifts +/// that have constant shift amounts: OuterShift (InnerShift X, C1), C2. +static bool canEvaluateShiftedShift(unsigned OuterShAmt, bool IsOuterShl, + Instruction *InnerShift, + InstCombinerImpl &IC, Instruction *CxtI) { + assert(InnerShift->isLogicalShift() && "Unexpected instruction type"); + + // We need constant scalar or constant splat shifts. 
+ const APInt *InnerShiftConst; + if (!match(InnerShift->getOperand(1), m_APInt(InnerShiftConst))) + return false; + + // Two logical shifts in the same direction: + // shl (shl X, C1), C2 --> shl X, C1 + C2 + // lshr (lshr X, C1), C2 --> lshr X, C1 + C2 + bool IsInnerShl = InnerShift->getOpcode() == Instruction::Shl; + if (IsInnerShl == IsOuterShl) + return true; + + // Equal shift amounts in opposite directions become bitwise 'and': + // lshr (shl X, C), C --> and X, C' + // shl (lshr X, C), C --> and X, C' + if (*InnerShiftConst == OuterShAmt) + return true; + + // If the 2nd shift is bigger than the 1st, we can fold: + // lshr (shl X, C1), C2 --> and (shl X, C1 - C2), C3 + // shl (lshr X, C1), C2 --> and (lshr X, C1 - C2), C3 + // but it isn't profitable unless we know the and'd out bits are already zero. + // Also, check that the inner shift is valid (less than the type width) or + // we'll crash trying to produce the bit mask for the 'and'. + unsigned TypeWidth = InnerShift->getType()->getScalarSizeInBits(); + if (InnerShiftConst->ugt(OuterShAmt) && InnerShiftConst->ult(TypeWidth)) { + unsigned InnerShAmt = InnerShiftConst->getZExtValue(); + unsigned MaskShift = + IsInnerShl ? TypeWidth - InnerShAmt : InnerShAmt - OuterShAmt; + APInt Mask = APInt::getLowBitsSet(TypeWidth, OuterShAmt) << MaskShift; + if (IC.MaskedValueIsZero(InnerShift->getOperand(0), Mask, 0, CxtI)) + return true; + } + + return false; +} + +/// See if we can compute the specified value, but shifted logically to the left +/// or right by some number of bits. This should return true if the expression +/// can be computed for the same cost as the current expression tree. This is +/// used to eliminate extraneous shifting from things like: +/// %C = shl i128 %A, 64 +/// %D = shl i128 %B, 96 +/// %E = or i128 %C, %D +/// %F = lshr i128 %E, 64 +/// where the client will ask if E can be computed shifted right by 64-bits. If +/// this succeeds, getShiftedValue() will be called to produce the value. 
+static bool canEvaluateShifted(Value *V, unsigned NumBits, bool IsLeftShift,
+                               InstCombinerImpl &IC, Instruction *CxtI) {
+  // Constants never block the transform: a shifted constant is always
+  // computable.
+  if (isa<Constant>(V))
+    return true;
+
+  // Beyond constants, only instructions are handled, and only those with a
+  // single use: rewriting a multi-use instruction would in general require
+  // duplicating it, which isn't profitable.
+  auto *Inst = dyn_cast<Instruction>(V);
+  if (!Inst || !Inst->hasOneUse())
+    return false;
+
+  // Asks the same question about one operand of Inst, with Inst itself as the
+  // context instruction for the nested query.
+  auto CanShiftOperand = [&](Value *Operand) {
+    return canEvaluateShifted(Operand, NumBits, IsLeftShift, IC, Inst);
+  };
+
+  switch (Inst->getOpcode()) {
+  case Instruction::And:
+  case Instruction::Or:
+  case Instruction::Xor:
+    // Bitwise logic can be evaluated shifted whenever both operands can.
+    return CanShiftOperand(Inst->getOperand(0)) &&
+           CanShiftOperand(Inst->getOperand(1));
+
+  case Instruction::Shl:
+  case Instruction::LShr:
+    // Logical shifts are delegated to canEvaluateShiftedShift(); note that
+    // the caller's context instruction is forwarded here, not Inst.
+    return canEvaluateShiftedShift(NumBits, IsLeftShift, Inst, IC, CxtI);
+
+  case Instruction::Select: {
+    // Only the two value arms are checked; the condition is not inspected.
+    auto *Sel = cast<SelectInst>(Inst);
+    return CanShiftOperand(Sel->getTrueValue()) &&
+           CanShiftOperand(Sel->getFalseValue());
+  }
+
+  case Instruction::PHI: {
+    // A phi qualifies only if every incoming value does. Cyclic phis cannot
+    // cause trouble here because only single-use instructions are considered.
+    for (Value *Incoming : cast<PHINode>(Inst)->incoming_values())
+      if (!CanShiftOperand(Incoming))
+        return false;
+    return true;
+  }
+
+  default:
+    // Every other opcode is rejected.
+    return false;
+  }
+}
+
+/// Fold OuterShift (InnerShift X, C1), C2.
+/// See canEvaluateShiftedShift() for the constraints on these instructions.
+static Value *foldShiftedShift(BinaryOperator *InnerShift, unsigned OuterShAmt, + bool IsOuterShl, + InstCombiner::BuilderTy &Builder) { + bool IsInnerShl = InnerShift->getOpcode() == Instruction::Shl; + Type *ShType = InnerShift->getType(); + unsigned TypeWidth = ShType->getScalarSizeInBits(); + + // We only accept shifts-by-a-constant in canEvaluateShifted(). + const APInt *C1; + match(InnerShift->getOperand(1), m_APInt(C1)); + unsigned InnerShAmt = C1->getZExtValue(); + + // Change the shift amount and clear the appropriate IR flags. + auto NewInnerShift = [&](unsigned ShAmt) { + InnerShift->setOperand(1, ConstantInt::get(ShType, ShAmt)); + if (IsInnerShl) { + InnerShift->setHasNoUnsignedWrap(false); + InnerShift->setHasNoSignedWrap(false); + } else { + InnerShift->setIsExact(false); + } + return InnerShift; + }; + + // Two logical shifts in the same direction: + // shl (shl X, C1), C2 --> shl X, C1 + C2 + // lshr (lshr X, C1), C2 --> lshr X, C1 + C2 + if (IsInnerShl == IsOuterShl) { + // If this is an oversized composite shift, then unsigned shifts get 0. + if (InnerShAmt + OuterShAmt >= TypeWidth) + return Constant::getNullValue(ShType); + + return NewInnerShift(InnerShAmt + OuterShAmt); + } + + // Equal shift amounts in opposite directions become bitwise 'and': + // lshr (shl X, C), C --> and X, C' + // shl (lshr X, C), C --> and X, C' + if (InnerShAmt == OuterShAmt) { + APInt Mask = IsInnerShl + ? 
APInt::getLowBitsSet(TypeWidth, TypeWidth - OuterShAmt) + : APInt::getHighBitsSet(TypeWidth, TypeWidth - OuterShAmt); + Value *And = Builder.CreateAnd(InnerShift->getOperand(0), + ConstantInt::get(ShType, Mask)); + if (auto *AndI = dyn_cast<Instruction>(And)) { + AndI->moveBefore(InnerShift); + AndI->takeName(InnerShift); + } + return And; + } + + assert(InnerShAmt > OuterShAmt && + "Unexpected opposite direction logical shift pair"); + + // In general, we would need an 'and' for this transform, but + // canEvaluateShiftedShift() guarantees that the masked-off bits are not used. + // lshr (shl X, C1), C2 --> shl X, C1 - C2 + // shl (lshr X, C1), C2 --> lshr X, C1 - C2 + return NewInnerShift(InnerShAmt - OuterShAmt); +} + +/// When canEvaluateShifted() returns true for an expression, this function +/// inserts the new computation that produces the shifted value. +static Value *getShiftedValue(Value *V, unsigned NumBits, bool isLeftShift, + InstCombinerImpl &IC, const DataLayout &DL) { + // We can always evaluate constants shifted. + if (Constant *C = dyn_cast<Constant>(V)) { + if (isLeftShift) + return IC.Builder.CreateShl(C, NumBits); + else + return IC.Builder.CreateLShr(C, NumBits); + } + + Instruction *I = cast<Instruction>(V); + IC.addToWorklist(I); + + switch (I->getOpcode()) { + default: llvm_unreachable("Inconsistency with CanEvaluateShifted"); + case Instruction::And: + case Instruction::Or: + case Instruction::Xor: + // Bitwise operators can all arbitrarily be arbitrarily evaluated shifted. 
+ I->setOperand( + 0, getShiftedValue(I->getOperand(0), NumBits, isLeftShift, IC, DL)); + I->setOperand( + 1, getShiftedValue(I->getOperand(1), NumBits, isLeftShift, IC, DL)); + return I; + + case Instruction::Shl: + case Instruction::LShr: + return foldShiftedShift(cast<BinaryOperator>(I), NumBits, isLeftShift, + IC.Builder); + + case Instruction::Select: + I->setOperand( + 1, getShiftedValue(I->getOperand(1), NumBits, isLeftShift, IC, DL)); + I->setOperand( + 2, getShiftedValue(I->getOperand(2), NumBits, isLeftShift, IC, DL)); + return I; + case Instruction::PHI: { + // We can change a phi if we can change all operands. Note that we never + // get into trouble with cyclic PHIs here because we only consider + // instructions with a single use. + PHINode *PN = cast<PHINode>(I); + for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) + PN->setIncomingValue(i, getShiftedValue(PN->getIncomingValue(i), NumBits, + isLeftShift, IC, DL)); + return PN; + } + } +} + +// If this is a bitwise operator or add with a constant RHS we might be able +// to pull it through a shift. +static bool canShiftBinOpWithConstantRHS(BinaryOperator &Shift, + BinaryOperator *BO) { + switch (BO->getOpcode()) { + default: + return false; // Do not perform transform! + case Instruction::Add: + return Shift.getOpcode() == Instruction::Shl; + case Instruction::Or: + case Instruction::And: + return true; + case Instruction::Xor: + // Do not change a 'not' of logical shift because that would create a normal + // 'xor'. The 'not' is likely better for analysis, SCEV, and codegen. 
+ return !(Shift.isLogicalShift() && match(BO, m_Not(m_Value()))); + } +} + +Instruction *InstCombinerImpl::FoldShiftByConstant(Value *Op0, Constant *C1, + BinaryOperator &I) { + // (C2 << X) << C1 --> (C2 << C1) << X + // (C2 >> X) >> C1 --> (C2 >> C1) >> X + Constant *C2; + Value *X; + if (match(Op0, m_BinOp(I.getOpcode(), m_Constant(C2), m_Value(X)))) + return BinaryOperator::Create( + I.getOpcode(), Builder.CreateBinOp(I.getOpcode(), C2, C1), X); + + const APInt *Op1C; + if (!match(C1, m_APInt(Op1C))) + return nullptr; + + // See if we can propagate this shift into the input, this covers the trivial + // cast of lshr(shl(x,c1),c2) as well as other more complex cases. + bool IsLeftShift = I.getOpcode() == Instruction::Shl; + if (I.getOpcode() != Instruction::AShr && + canEvaluateShifted(Op0, Op1C->getZExtValue(), IsLeftShift, *this, &I)) { + LLVM_DEBUG( + dbgs() << "ICE: GetShiftedValue propagating shift through expression" + " to eliminate shift:\n IN: " + << *Op0 << "\n SH: " << I << "\n"); + + return replaceInstUsesWith( + I, getShiftedValue(Op0, Op1C->getZExtValue(), IsLeftShift, *this, DL)); + } + + // See if we can simplify any instructions used by the instruction whose sole + // purpose is to compute bits we don't care about. + Type *Ty = I.getType(); + unsigned TypeBits = Ty->getScalarSizeInBits(); + assert(!Op1C->uge(TypeBits) && + "Shift over the type width should have been removed already"); + (void)TypeBits; + + if (Instruction *FoldedShift = foldBinOpIntoSelectOrPhi(I)) + return FoldedShift; + + if (!Op0->hasOneUse()) + return nullptr; + + if (auto *Op0BO = dyn_cast<BinaryOperator>(Op0)) { + // If the operand is a bitwise operator with a constant RHS, and the + // shift is the only use, we can pull it out of the shift. 
+ const APInt *Op0C; + if (match(Op0BO->getOperand(1), m_APInt(Op0C))) { + if (canShiftBinOpWithConstantRHS(I, Op0BO)) { + Value *NewRHS = + Builder.CreateBinOp(I.getOpcode(), Op0BO->getOperand(1), C1); + + Value *NewShift = + Builder.CreateBinOp(I.getOpcode(), Op0BO->getOperand(0), C1); + NewShift->takeName(Op0BO); + + return BinaryOperator::Create(Op0BO->getOpcode(), NewShift, NewRHS); + } + } + } + + // If we have a select that conditionally executes some binary operator, + // see if we can pull it the select and operator through the shift. + // + // For example, turning: + // shl (select C, (add X, C1), X), C2 + // Into: + // Y = shl X, C2 + // select C, (add Y, C1 << C2), Y + Value *Cond; + BinaryOperator *TBO; + Value *FalseVal; + if (match(Op0, m_Select(m_Value(Cond), m_OneUse(m_BinOp(TBO)), + m_Value(FalseVal)))) { + const APInt *C; + if (!isa<Constant>(FalseVal) && TBO->getOperand(0) == FalseVal && + match(TBO->getOperand(1), m_APInt(C)) && + canShiftBinOpWithConstantRHS(I, TBO)) { + Value *NewRHS = + Builder.CreateBinOp(I.getOpcode(), TBO->getOperand(1), C1); + + Value *NewShift = Builder.CreateBinOp(I.getOpcode(), FalseVal, C1); + Value *NewOp = Builder.CreateBinOp(TBO->getOpcode(), NewShift, NewRHS); + return SelectInst::Create(Cond, NewOp, NewShift); + } + } + + BinaryOperator *FBO; + Value *TrueVal; + if (match(Op0, m_Select(m_Value(Cond), m_Value(TrueVal), + m_OneUse(m_BinOp(FBO))))) { + const APInt *C; + if (!isa<Constant>(TrueVal) && FBO->getOperand(0) == TrueVal && + match(FBO->getOperand(1), m_APInt(C)) && + canShiftBinOpWithConstantRHS(I, FBO)) { + Value *NewRHS = + Builder.CreateBinOp(I.getOpcode(), FBO->getOperand(1), C1); + + Value *NewShift = Builder.CreateBinOp(I.getOpcode(), TrueVal, C1); + Value *NewOp = Builder.CreateBinOp(FBO->getOpcode(), NewShift, NewRHS); + return SelectInst::Create(Cond, NewShift, NewOp); + } + } + + return nullptr; +} + +Instruction *InstCombinerImpl::visitShl(BinaryOperator &I) { + const SimplifyQuery Q = 
SQ.getWithInstruction(&I); + + if (Value *V = simplifyShlInst(I.getOperand(0), I.getOperand(1), + I.hasNoSignedWrap(), I.hasNoUnsignedWrap(), Q)) + return replaceInstUsesWith(I, V); + + if (Instruction *X = foldVectorBinop(I)) + return X; + + if (Instruction *V = commonShiftTransforms(I)) + return V; + + if (Instruction *V = dropRedundantMaskingOfLeftShiftInput(&I, Q, Builder)) + return V; + + Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); + Type *Ty = I.getType(); + unsigned BitWidth = Ty->getScalarSizeInBits(); + + const APInt *C; + if (match(Op1, m_APInt(C))) { + unsigned ShAmtC = C->getZExtValue(); + + // shl (zext X), C --> zext (shl X, C) + // This is only valid if X would have zeros shifted out. + Value *X; + if (match(Op0, m_OneUse(m_ZExt(m_Value(X))))) { + unsigned SrcWidth = X->getType()->getScalarSizeInBits(); + if (ShAmtC < SrcWidth && + MaskedValueIsZero(X, APInt::getHighBitsSet(SrcWidth, ShAmtC), 0, &I)) + return new ZExtInst(Builder.CreateShl(X, ShAmtC), Ty); + } + + // (X >> C) << C --> X & (-1 << C) + if (match(Op0, m_Shr(m_Value(X), m_Specific(Op1)))) { + APInt Mask(APInt::getHighBitsSet(BitWidth, BitWidth - ShAmtC)); + return BinaryOperator::CreateAnd(X, ConstantInt::get(Ty, Mask)); + } + + const APInt *C1; + if (match(Op0, m_Exact(m_Shr(m_Value(X), m_APInt(C1)))) && + C1->ult(BitWidth)) { + unsigned ShrAmt = C1->getZExtValue(); + if (ShrAmt < ShAmtC) { + // If C1 < C: (X >>?,exact C1) << C --> X << (C - C1) + Constant *ShiftDiff = ConstantInt::get(Ty, ShAmtC - ShrAmt); + auto *NewShl = BinaryOperator::CreateShl(X, ShiftDiff); + NewShl->setHasNoUnsignedWrap(I.hasNoUnsignedWrap()); + NewShl->setHasNoSignedWrap(I.hasNoSignedWrap()); + return NewShl; + } + if (ShrAmt > ShAmtC) { + // If C1 > C: (X >>?exact C1) << C --> X >>?exact (C1 - C) + Constant *ShiftDiff = ConstantInt::get(Ty, ShrAmt - ShAmtC); + auto *NewShr = BinaryOperator::Create( + cast<BinaryOperator>(Op0)->getOpcode(), X, ShiftDiff); + NewShr->setIsExact(true); + return NewShr; + 
} + } + + if (match(Op0, m_OneUse(m_Shr(m_Value(X), m_APInt(C1)))) && + C1->ult(BitWidth)) { + unsigned ShrAmt = C1->getZExtValue(); + if (ShrAmt < ShAmtC) { + // If C1 < C: (X >>? C1) << C --> (X << (C - C1)) & (-1 << C) + Constant *ShiftDiff = ConstantInt::get(Ty, ShAmtC - ShrAmt); + auto *NewShl = BinaryOperator::CreateShl(X, ShiftDiff); + NewShl->setHasNoUnsignedWrap(I.hasNoUnsignedWrap()); + NewShl->setHasNoSignedWrap(I.hasNoSignedWrap()); + Builder.Insert(NewShl); + APInt Mask(APInt::getHighBitsSet(BitWidth, BitWidth - ShAmtC)); + return BinaryOperator::CreateAnd(NewShl, ConstantInt::get(Ty, Mask)); + } + if (ShrAmt > ShAmtC) { + // If C1 > C: (X >>? C1) << C --> (X >>? (C1 - C)) & (-1 << C) + Constant *ShiftDiff = ConstantInt::get(Ty, ShrAmt - ShAmtC); + auto *OldShr = cast<BinaryOperator>(Op0); + auto *NewShr = + BinaryOperator::Create(OldShr->getOpcode(), X, ShiftDiff); + NewShr->setIsExact(OldShr->isExact()); + Builder.Insert(NewShr); + APInt Mask(APInt::getHighBitsSet(BitWidth, BitWidth - ShAmtC)); + return BinaryOperator::CreateAnd(NewShr, ConstantInt::get(Ty, Mask)); + } + } + + // Similar to above, but look through an intermediate trunc instruction. + BinaryOperator *Shr; + if (match(Op0, m_OneUse(m_Trunc(m_OneUse(m_BinOp(Shr))))) && + match(Shr, m_Shr(m_Value(X), m_APInt(C1)))) { + // The larger shift direction survives through the transform. + unsigned ShrAmtC = C1->getZExtValue(); + unsigned ShDiff = ShrAmtC > ShAmtC ? ShrAmtC - ShAmtC : ShAmtC - ShrAmtC; + Constant *ShiftDiffC = ConstantInt::get(X->getType(), ShDiff); + auto ShiftOpc = ShrAmtC > ShAmtC ? 
Shr->getOpcode() : Instruction::Shl; + + // If C1 > C: + // (trunc (X >> C1)) << C --> (trunc (X >> (C1 - C))) && (-1 << C) + // If C > C1: + // (trunc (X >> C1)) << C --> (trunc (X << (C - C1))) && (-1 << C) + Value *NewShift = Builder.CreateBinOp(ShiftOpc, X, ShiftDiffC, "sh.diff"); + Value *Trunc = Builder.CreateTrunc(NewShift, Ty, "tr.sh.diff"); + APInt Mask(APInt::getHighBitsSet(BitWidth, BitWidth - ShAmtC)); + return BinaryOperator::CreateAnd(Trunc, ConstantInt::get(Ty, Mask)); + } + + if (match(Op0, m_Shl(m_Value(X), m_APInt(C1))) && C1->ult(BitWidth)) { + unsigned AmtSum = ShAmtC + C1->getZExtValue(); + // Oversized shifts are simplified to zero in InstSimplify. + if (AmtSum < BitWidth) + // (X << C1) << C2 --> X << (C1 + C2) + return BinaryOperator::CreateShl(X, ConstantInt::get(Ty, AmtSum)); + } + + // If we have an opposite shift by the same amount, we may be able to + // reorder binops and shifts to eliminate math/logic. + auto isSuitableBinOpcode = [](Instruction::BinaryOps BinOpcode) { + switch (BinOpcode) { + default: + return false; + case Instruction::Add: + case Instruction::And: + case Instruction::Or: + case Instruction::Xor: + case Instruction::Sub: + // NOTE: Sub is not commutable and the tranforms below may not be valid + // when the shift-right is operand 1 (RHS) of the sub. + return true; + } + }; + BinaryOperator *Op0BO; + if (match(Op0, m_OneUse(m_BinOp(Op0BO))) && + isSuitableBinOpcode(Op0BO->getOpcode())) { + // Commute so shift-right is on LHS of the binop. 
+ // (Y bop (X >> C)) << C -> ((X >> C) bop Y) << C + // (Y bop ((X >> C) & CC)) << C -> (((X >> C) & CC) bop Y) << C + Value *Shr = Op0BO->getOperand(0); + Value *Y = Op0BO->getOperand(1); + Value *X; + const APInt *CC; + if (Op0BO->isCommutative() && Y->hasOneUse() && + (match(Y, m_Shr(m_Value(), m_Specific(Op1))) || + match(Y, m_And(m_OneUse(m_Shr(m_Value(), m_Specific(Op1))), + m_APInt(CC))))) + std::swap(Shr, Y); + + // ((X >> C) bop Y) << C -> (X bop (Y << C)) & (~0 << C) + if (match(Shr, m_OneUse(m_Shr(m_Value(X), m_Specific(Op1))))) { + // Y << C + Value *YS = Builder.CreateShl(Y, Op1, Op0BO->getName()); + // (X bop (Y << C)) + Value *B = + Builder.CreateBinOp(Op0BO->getOpcode(), X, YS, Shr->getName()); + unsigned Op1Val = C->getLimitedValue(BitWidth); + APInt Bits = APInt::getHighBitsSet(BitWidth, BitWidth - Op1Val); + Constant *Mask = ConstantInt::get(Ty, Bits); + return BinaryOperator::CreateAnd(B, Mask); + } + + // (((X >> C) & CC) bop Y) << C -> (X & (CC << C)) bop (Y << C) + if (match(Shr, + m_OneUse(m_And(m_OneUse(m_Shr(m_Value(X), m_Specific(Op1))), + m_APInt(CC))))) { + // Y << C + Value *YS = Builder.CreateShl(Y, Op1, Op0BO->getName()); + // X & (CC << C) + Value *M = Builder.CreateAnd(X, ConstantInt::get(Ty, CC->shl(*C)), + X->getName() + ".mask"); + return BinaryOperator::Create(Op0BO->getOpcode(), M, YS); + } + } + + // (C1 - X) << C --> (C1 << C) - (X << C) + if (match(Op0, m_OneUse(m_Sub(m_APInt(C1), m_Value(X))))) { + Constant *NewLHS = ConstantInt::get(Ty, C1->shl(*C)); + Value *NewShift = Builder.CreateShl(X, Op1); + return BinaryOperator::CreateSub(NewLHS, NewShift); + } + + // If the shifted-out value is known-zero, then this is a NUW shift. + if (!I.hasNoUnsignedWrap() && + MaskedValueIsZero(Op0, APInt::getHighBitsSet(BitWidth, ShAmtC), 0, + &I)) { + I.setHasNoUnsignedWrap(); + return &I; + } + + // If the shifted-out value is all signbits, then this is a NSW shift. 
+ if (!I.hasNoSignedWrap() && ComputeNumSignBits(Op0, 0, &I) > ShAmtC) { + I.setHasNoSignedWrap(); + return &I; + } + } + + // Transform (x >> y) << y to x & (-1 << y) + // Valid for any type of right-shift. + Value *X; + if (match(Op0, m_OneUse(m_Shr(m_Value(X), m_Specific(Op1))))) { + Constant *AllOnes = ConstantInt::getAllOnesValue(Ty); + Value *Mask = Builder.CreateShl(AllOnes, Op1); + return BinaryOperator::CreateAnd(Mask, X); + } + + Constant *C1; + if (match(Op1, m_Constant(C1))) { + Constant *C2; + Value *X; + // (X * C2) << C1 --> X * (C2 << C1) + if (match(Op0, m_Mul(m_Value(X), m_Constant(C2)))) + return BinaryOperator::CreateMul(X, ConstantExpr::getShl(C2, C1)); + + // shl (zext i1 X), C1 --> select (X, 1 << C1, 0) + if (match(Op0, m_ZExt(m_Value(X))) && X->getType()->isIntOrIntVectorTy(1)) { + auto *NewC = ConstantExpr::getShl(ConstantInt::get(Ty, 1), C1); + return SelectInst::Create(X, NewC, ConstantInt::getNullValue(Ty)); + } + } + + // (1 << (C - x)) -> ((1 << C) >> x) if C is bitwidth - 1 + if (match(Op0, m_One()) && + match(Op1, m_Sub(m_SpecificInt(BitWidth - 1), m_Value(X)))) + return BinaryOperator::CreateLShr( + ConstantInt::get(Ty, APInt::getSignMask(BitWidth)), X); + + return nullptr; +} + +Instruction *InstCombinerImpl::visitLShr(BinaryOperator &I) { + if (Value *V = simplifyLShrInst(I.getOperand(0), I.getOperand(1), I.isExact(), + SQ.getWithInstruction(&I))) + return replaceInstUsesWith(I, V); + + if (Instruction *X = foldVectorBinop(I)) + return X; + + if (Instruction *R = commonShiftTransforms(I)) + return R; + + Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); + Type *Ty = I.getType(); + const APInt *C; + if (match(Op1, m_APInt(C))) { + unsigned ShAmtC = C->getZExtValue(); + unsigned BitWidth = Ty->getScalarSizeInBits(); + auto *II = dyn_cast<IntrinsicInst>(Op0); + if (II && isPowerOf2_32(BitWidth) && Log2_32(BitWidth) == ShAmtC && + (II->getIntrinsicID() == Intrinsic::ctlz || + II->getIntrinsicID() == Intrinsic::cttz || + 
II->getIntrinsicID() == Intrinsic::ctpop)) { + // ctlz.i32(x)>>5 --> zext(x == 0) + // cttz.i32(x)>>5 --> zext(x == 0) + // ctpop.i32(x)>>5 --> zext(x == -1) + bool IsPop = II->getIntrinsicID() == Intrinsic::ctpop; + Constant *RHS = ConstantInt::getSigned(Ty, IsPop ? -1 : 0); + Value *Cmp = Builder.CreateICmpEQ(II->getArgOperand(0), RHS); + return new ZExtInst(Cmp, Ty); + } + + Value *X; + const APInt *C1; + if (match(Op0, m_Shl(m_Value(X), m_APInt(C1))) && C1->ult(BitWidth)) { + if (C1->ult(ShAmtC)) { + unsigned ShlAmtC = C1->getZExtValue(); + Constant *ShiftDiff = ConstantInt::get(Ty, ShAmtC - ShlAmtC); + if (cast<BinaryOperator>(Op0)->hasNoUnsignedWrap()) { + // (X <<nuw C1) >>u C --> X >>u (C - C1) + auto *NewLShr = BinaryOperator::CreateLShr(X, ShiftDiff); + NewLShr->setIsExact(I.isExact()); + return NewLShr; + } + if (Op0->hasOneUse()) { + // (X << C1) >>u C --> (X >>u (C - C1)) & (-1 >> C) + Value *NewLShr = Builder.CreateLShr(X, ShiftDiff, "", I.isExact()); + APInt Mask(APInt::getLowBitsSet(BitWidth, BitWidth - ShAmtC)); + return BinaryOperator::CreateAnd(NewLShr, ConstantInt::get(Ty, Mask)); + } + } else if (C1->ugt(ShAmtC)) { + unsigned ShlAmtC = C1->getZExtValue(); + Constant *ShiftDiff = ConstantInt::get(Ty, ShlAmtC - ShAmtC); + if (cast<BinaryOperator>(Op0)->hasNoUnsignedWrap()) { + // (X <<nuw C1) >>u C --> X <<nuw (C1 - C) + auto *NewShl = BinaryOperator::CreateShl(X, ShiftDiff); + NewShl->setHasNoUnsignedWrap(true); + return NewShl; + } + if (Op0->hasOneUse()) { + // (X << C1) >>u C --> X << (C1 - C) & (-1 >> C) + Value *NewShl = Builder.CreateShl(X, ShiftDiff); + APInt Mask(APInt::getLowBitsSet(BitWidth, BitWidth - ShAmtC)); + return BinaryOperator::CreateAnd(NewShl, ConstantInt::get(Ty, Mask)); + } + } else { + assert(*C1 == ShAmtC); + // (X << C) >>u C --> X & (-1 >>u C) + APInt Mask(APInt::getLowBitsSet(BitWidth, BitWidth - ShAmtC)); + return BinaryOperator::CreateAnd(X, ConstantInt::get(Ty, Mask)); + } + } + + // ((X << C) + Y) >>u C --> (X + 
(Y >>u C)) & (-1 >>u C) + // TODO: Consolidate with the more general transform that starts from shl + // (the shifts are in the opposite order). + Value *Y; + if (match(Op0, + m_OneUse(m_c_Add(m_OneUse(m_Shl(m_Value(X), m_Specific(Op1))), + m_Value(Y))))) { + Value *NewLshr = Builder.CreateLShr(Y, Op1); + Value *NewAdd = Builder.CreateAdd(NewLshr, X); + unsigned Op1Val = C->getLimitedValue(BitWidth); + APInt Bits = APInt::getLowBitsSet(BitWidth, BitWidth - Op1Val); + Constant *Mask = ConstantInt::get(Ty, Bits); + return BinaryOperator::CreateAnd(NewAdd, Mask); + } + + if (match(Op0, m_OneUse(m_ZExt(m_Value(X)))) && + (!Ty->isIntegerTy() || shouldChangeType(Ty, X->getType()))) { + assert(ShAmtC < X->getType()->getScalarSizeInBits() && + "Big shift not simplified to zero?"); + // lshr (zext iM X to iN), C --> zext (lshr X, C) to iN + Value *NewLShr = Builder.CreateLShr(X, ShAmtC); + return new ZExtInst(NewLShr, Ty); + } + + if (match(Op0, m_SExt(m_Value(X)))) { + unsigned SrcTyBitWidth = X->getType()->getScalarSizeInBits(); + // lshr (sext i1 X to iN), C --> select (X, -1 >> C, 0) + if (SrcTyBitWidth == 1) { + auto *NewC = ConstantInt::get( + Ty, APInt::getLowBitsSet(BitWidth, BitWidth - ShAmtC)); + return SelectInst::Create(X, NewC, ConstantInt::getNullValue(Ty)); + } + + if ((!Ty->isIntegerTy() || shouldChangeType(Ty, X->getType())) && + Op0->hasOneUse()) { + // Are we moving the sign bit to the low bit and widening with high + // zeros? lshr (sext iM X to iN), N-1 --> zext (lshr X, M-1) to iN + if (ShAmtC == BitWidth - 1) { + Value *NewLShr = Builder.CreateLShr(X, SrcTyBitWidth - 1); + return new ZExtInst(NewLShr, Ty); + } + + // lshr (sext iM X to iN), N-M --> zext (ashr X, min(N-M, M-1)) to iN + if (ShAmtC == BitWidth - SrcTyBitWidth) { + // The new shift amount can't be more than the narrow source type. 
+ unsigned NewShAmt = std::min(ShAmtC, SrcTyBitWidth - 1); + Value *AShr = Builder.CreateAShr(X, NewShAmt); + return new ZExtInst(AShr, Ty); + } + } + } + + if (ShAmtC == BitWidth - 1) { + // lshr i32 or(X,-X), 31 --> zext (X != 0) + if (match(Op0, m_OneUse(m_c_Or(m_Neg(m_Value(X)), m_Deferred(X))))) + return new ZExtInst(Builder.CreateIsNotNull(X), Ty); + + // lshr i32 (X -nsw Y), 31 --> zext (X < Y) + if (match(Op0, m_OneUse(m_NSWSub(m_Value(X), m_Value(Y))))) + return new ZExtInst(Builder.CreateICmpSLT(X, Y), Ty); + + // Check if a number is negative and odd: + // lshr i32 (srem X, 2), 31 --> and (X >> 31), X + if (match(Op0, m_OneUse(m_SRem(m_Value(X), m_SpecificInt(2))))) { + Value *Signbit = Builder.CreateLShr(X, ShAmtC); + return BinaryOperator::CreateAnd(Signbit, X); + } + } + + // (X >>u C1) >>u C --> X >>u (C1 + C) + if (match(Op0, m_LShr(m_Value(X), m_APInt(C1)))) { + // Oversized shifts are simplified to zero in InstSimplify. + unsigned AmtSum = ShAmtC + C1->getZExtValue(); + if (AmtSum < BitWidth) + return BinaryOperator::CreateLShr(X, ConstantInt::get(Ty, AmtSum)); + } + + Instruction *TruncSrc; + if (match(Op0, m_OneUse(m_Trunc(m_Instruction(TruncSrc)))) && + match(TruncSrc, m_LShr(m_Value(X), m_APInt(C1)))) { + unsigned SrcWidth = X->getType()->getScalarSizeInBits(); + unsigned AmtSum = ShAmtC + C1->getZExtValue(); + + // If the combined shift fits in the source width: + // (trunc (X >>u C1)) >>u C --> and (trunc (X >>u (C1 + C)), MaskC + // + // If the first shift covers the number of bits truncated, then the + // mask instruction is eliminated (and so the use check is relaxed). + if (AmtSum < SrcWidth && + (TruncSrc->hasOneUse() || C1->uge(SrcWidth - BitWidth))) { + Value *SumShift = Builder.CreateLShr(X, AmtSum, "sum.shift"); + Value *Trunc = Builder.CreateTrunc(SumShift, Ty, I.getName()); + + // If the first shift does not cover the number of bits truncated, then + // we require a mask to get rid of high bits in the result. 
+ APInt MaskC = APInt::getAllOnes(BitWidth).lshr(ShAmtC); + return BinaryOperator::CreateAnd(Trunc, ConstantInt::get(Ty, MaskC)); + } + } + + const APInt *MulC; + if (match(Op0, m_NUWMul(m_Value(X), m_APInt(MulC)))) { + // Look for a "splat" mul pattern - it replicates bits across each half of + // a value, so a right shift is just a mask of the low bits: + // lshr i[2N] (mul nuw X, (2^N)+1), N --> and iN X, (2^N)-1 + // TODO: Generalize to allow more than just half-width shifts? + if (BitWidth > 2 && ShAmtC * 2 == BitWidth && (*MulC - 1).isPowerOf2() && + MulC->logBase2() == ShAmtC) + return BinaryOperator::CreateAnd(X, ConstantInt::get(Ty, *MulC - 2)); + + // The one-use check is not strictly necessary, but codegen may not be + // able to invert the transform and perf may suffer with an extra mul + // instruction. + if (Op0->hasOneUse()) { + APInt NewMulC = MulC->lshr(ShAmtC); + // if c is divisible by (1 << ShAmtC): + // lshr (mul nuw x, MulC), ShAmtC -> mul nuw x, (MulC >> ShAmtC) + if (MulC->eq(NewMulC.shl(ShAmtC))) { + auto *NewMul = + BinaryOperator::CreateNUWMul(X, ConstantInt::get(Ty, NewMulC)); + BinaryOperator *OrigMul = cast<BinaryOperator>(Op0); + NewMul->setHasNoSignedWrap(OrigMul->hasNoSignedWrap()); + return NewMul; + } + } + } + + // Try to narrow bswap. + // In the case where the shift amount equals the bitwidth difference, the + // shift is eliminated. 
+ if (match(Op0, m_OneUse(m_Intrinsic<Intrinsic::bswap>( + m_OneUse(m_ZExt(m_Value(X))))))) { + unsigned SrcWidth = X->getType()->getScalarSizeInBits(); + unsigned WidthDiff = BitWidth - SrcWidth; + if (SrcWidth % 16 == 0) { + Value *NarrowSwap = Builder.CreateUnaryIntrinsic(Intrinsic::bswap, X); + if (ShAmtC >= WidthDiff) { + // (bswap (zext X)) >> C --> zext (bswap X >> C') + Value *NewShift = Builder.CreateLShr(NarrowSwap, ShAmtC - WidthDiff); + return new ZExtInst(NewShift, Ty); + } else { + // (bswap (zext X)) >> C --> (zext (bswap X)) << C' + Value *NewZExt = Builder.CreateZExt(NarrowSwap, Ty); + Constant *ShiftDiff = ConstantInt::get(Ty, WidthDiff - ShAmtC); + return BinaryOperator::CreateShl(NewZExt, ShiftDiff); + } + } + } + + // If the shifted-out value is known-zero, then this is an exact shift. + if (!I.isExact() && + MaskedValueIsZero(Op0, APInt::getLowBitsSet(BitWidth, ShAmtC), 0, &I)) { + I.setIsExact(); + return &I; + } + } + + // Transform (x << y) >> y to x & (-1 >> y) + Value *X; + if (match(Op0, m_OneUse(m_Shl(m_Value(X), m_Specific(Op1))))) { + Constant *AllOnes = ConstantInt::getAllOnesValue(Ty); + Value *Mask = Builder.CreateLShr(AllOnes, Op1); + return BinaryOperator::CreateAnd(Mask, X); + } + + return nullptr; +} + +Instruction * +InstCombinerImpl::foldVariableSignZeroExtensionOfVariableHighBitExtract( + BinaryOperator &OldAShr) { + assert(OldAShr.getOpcode() == Instruction::AShr && + "Must be called with arithmetic right-shift instruction only."); + + // Check that constant C is a splat of the element-wise bitwidth of V. 
+ auto BitWidthSplat = [](Constant *C, Value *V) { + return match( + C, m_SpecificInt_ICMP(ICmpInst::Predicate::ICMP_EQ, + APInt(C->getType()->getScalarSizeInBits(), + V->getType()->getScalarSizeInBits()))); + }; + + // It should look like variable-length sign-extension on the outside: + // (Val << (bitwidth(Val)-Nbits)) a>> (bitwidth(Val)-Nbits) + Value *NBits; + Instruction *MaybeTrunc; + Constant *C1, *C2; + if (!match(&OldAShr, + m_AShr(m_Shl(m_Instruction(MaybeTrunc), + m_ZExtOrSelf(m_Sub(m_Constant(C1), + m_ZExtOrSelf(m_Value(NBits))))), + m_ZExtOrSelf(m_Sub(m_Constant(C2), + m_ZExtOrSelf(m_Deferred(NBits)))))) || + !BitWidthSplat(C1, &OldAShr) || !BitWidthSplat(C2, &OldAShr)) + return nullptr; + + // There may or may not be a truncation after outer two shifts. + Instruction *HighBitExtract; + match(MaybeTrunc, m_TruncOrSelf(m_Instruction(HighBitExtract))); + bool HadTrunc = MaybeTrunc != HighBitExtract; + + // And finally, the innermost part of the pattern must be a right-shift. + Value *X, *NumLowBitsToSkip; + if (!match(HighBitExtract, m_Shr(m_Value(X), m_Value(NumLowBitsToSkip)))) + return nullptr; + + // Said right-shift must extract high NBits bits - C0 must be it's bitwidth. + Constant *C0; + if (!match(NumLowBitsToSkip, + m_ZExtOrSelf( + m_Sub(m_Constant(C0), m_ZExtOrSelf(m_Specific(NBits))))) || + !BitWidthSplat(C0, HighBitExtract)) + return nullptr; + + // Since the NBits is identical for all shifts, if the outermost and + // innermost shifts are identical, then outermost shifts are redundant. + // If we had truncation, do keep it though. + if (HighBitExtract->getOpcode() == OldAShr.getOpcode()) + return replaceInstUsesWith(OldAShr, MaybeTrunc); + + // Else, if there was a truncation, then we need to ensure that one + // instruction will go away. 
+ if (HadTrunc && !match(&OldAShr, m_c_BinOp(m_OneUse(m_Value()), m_Value()))) + return nullptr; + + // Finally, bypass two innermost shifts, and perform the outermost shift on + // the operands of the innermost shift. + Instruction *NewAShr = + BinaryOperator::Create(OldAShr.getOpcode(), X, NumLowBitsToSkip); + NewAShr->copyIRFlags(HighBitExtract); // We can preserve 'exact'-ness. + if (!HadTrunc) + return NewAShr; + + Builder.Insert(NewAShr); + return TruncInst::CreateTruncOrBitCast(NewAShr, OldAShr.getType()); +} + +Instruction *InstCombinerImpl::visitAShr(BinaryOperator &I) { + if (Value *V = simplifyAShrInst(I.getOperand(0), I.getOperand(1), I.isExact(), + SQ.getWithInstruction(&I))) + return replaceInstUsesWith(I, V); + + if (Instruction *X = foldVectorBinop(I)) + return X; + + if (Instruction *R = commonShiftTransforms(I)) + return R; + + Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); + Type *Ty = I.getType(); + unsigned BitWidth = Ty->getScalarSizeInBits(); + const APInt *ShAmtAPInt; + if (match(Op1, m_APInt(ShAmtAPInt)) && ShAmtAPInt->ult(BitWidth)) { + unsigned ShAmt = ShAmtAPInt->getZExtValue(); + + // If the shift amount equals the difference in width of the destination + // and source scalar types: + // ashr (shl (zext X), C), C --> sext X + Value *X; + if (match(Op0, m_Shl(m_ZExt(m_Value(X)), m_Specific(Op1))) && + ShAmt == BitWidth - X->getType()->getScalarSizeInBits()) + return new SExtInst(X, Ty); + + // We can't handle (X << C1) >>s C2. It shifts arbitrary bits in. However, + // we can handle (X <<nsw C1) >>s C2 since it only shifts in sign bits. 
+ const APInt *ShOp1; + if (match(Op0, m_NSWShl(m_Value(X), m_APInt(ShOp1))) && + ShOp1->ult(BitWidth)) { + unsigned ShlAmt = ShOp1->getZExtValue(); + if (ShlAmt < ShAmt) { + // (X <<nsw C1) >>s C2 --> X >>s (C2 - C1) + Constant *ShiftDiff = ConstantInt::get(Ty, ShAmt - ShlAmt); + auto *NewAShr = BinaryOperator::CreateAShr(X, ShiftDiff); + NewAShr->setIsExact(I.isExact()); + return NewAShr; + } + if (ShlAmt > ShAmt) { + // (X <<nsw C1) >>s C2 --> X <<nsw (C1 - C2) + Constant *ShiftDiff = ConstantInt::get(Ty, ShlAmt - ShAmt); + auto *NewShl = BinaryOperator::Create(Instruction::Shl, X, ShiftDiff); + NewShl->setHasNoSignedWrap(true); + return NewShl; + } + } + + if (match(Op0, m_AShr(m_Value(X), m_APInt(ShOp1))) && + ShOp1->ult(BitWidth)) { + unsigned AmtSum = ShAmt + ShOp1->getZExtValue(); + // Oversized arithmetic shifts replicate the sign bit. + AmtSum = std::min(AmtSum, BitWidth - 1); + // (X >>s C1) >>s C2 --> X >>s (C1 + C2) + return BinaryOperator::CreateAShr(X, ConstantInt::get(Ty, AmtSum)); + } + + if (match(Op0, m_OneUse(m_SExt(m_Value(X)))) && + (Ty->isVectorTy() || shouldChangeType(Ty, X->getType()))) { + // ashr (sext X), C --> sext (ashr X, C') + Type *SrcTy = X->getType(); + ShAmt = std::min(ShAmt, SrcTy->getScalarSizeInBits() - 1); + Value *NewSh = Builder.CreateAShr(X, ConstantInt::get(SrcTy, ShAmt)); + return new SExtInst(NewSh, Ty); + } + + if (ShAmt == BitWidth - 1) { + // ashr i32 or(X,-X), 31 --> sext (X != 0) + if (match(Op0, m_OneUse(m_c_Or(m_Neg(m_Value(X)), m_Deferred(X))))) + return new SExtInst(Builder.CreateIsNotNull(X), Ty); + + // ashr i32 (X -nsw Y), 31 --> sext (X < Y) + Value *Y; + if (match(Op0, m_OneUse(m_NSWSub(m_Value(X), m_Value(Y))))) + return new SExtInst(Builder.CreateICmpSLT(X, Y), Ty); + } + + // If the shifted-out value is known-zero, then this is an exact shift. 
+ if (!I.isExact() && + MaskedValueIsZero(Op0, APInt::getLowBitsSet(BitWidth, ShAmt), 0, &I)) { + I.setIsExact(); + return &I; + } + } + + // Prefer `-(x & 1)` over `(x << (bitwidth(x)-1)) a>> (bitwidth(x)-1)` + // as the pattern to splat the lowest bit. + // FIXME: iff X is already masked, we don't need the one-use check. + Value *X; + if (match(Op1, m_SpecificIntAllowUndef(BitWidth - 1)) && + match(Op0, m_OneUse(m_Shl(m_Value(X), + m_SpecificIntAllowUndef(BitWidth - 1))))) { + Constant *Mask = ConstantInt::get(Ty, 1); + // Retain the knowledge about the ignored lanes. + Mask = Constant::mergeUndefsWith( + Constant::mergeUndefsWith(Mask, cast<Constant>(Op1)), + cast<Constant>(cast<Instruction>(Op0)->getOperand(1))); + X = Builder.CreateAnd(X, Mask); + return BinaryOperator::CreateNeg(X); + } + + if (Instruction *R = foldVariableSignZeroExtensionOfVariableHighBitExtract(I)) + return R; + + // See if we can turn a signed shr into an unsigned shr. + if (MaskedValueIsZero(Op0, APInt::getSignMask(BitWidth), 0, &I)) + return BinaryOperator::CreateLShr(Op0, Op1); + + // ashr (xor %x, -1), %y --> xor (ashr %x, %y), -1 + if (match(Op0, m_OneUse(m_Not(m_Value(X))))) { + // Note that we must drop 'exact'-ness of the shift! + // Note that we can't keep undef's in -1 vector constant! + auto *NewAShr = Builder.CreateAShr(X, Op1, Op0->getName() + ".not"); + return BinaryOperator::CreateNot(NewAShr); + } + + return nullptr; +} diff --git a/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp b/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp new file mode 100644 index 000000000000..9d4c01ac03e2 --- /dev/null +++ b/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp @@ -0,0 +1,1665 @@ +//===- InstCombineSimplifyDemanded.cpp ------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 
+// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains logic for simplifying instructions based on information +// about how they are used. +// +//===----------------------------------------------------------------------===// + +#include "InstCombineInternal.h" +#include "llvm/Analysis/ValueTracking.h" +#include "llvm/IR/GetElementPtrTypeIterator.h" +#include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/PatternMatch.h" +#include "llvm/Support/KnownBits.h" +#include "llvm/Transforms/InstCombine/InstCombiner.h" + +using namespace llvm; +using namespace llvm::PatternMatch; + +#define DEBUG_TYPE "instcombine" + +/// Check to see if the specified operand of the specified instruction is a +/// constant integer. If so, check to see if there are any bits set in the +/// constant that are not demanded. If so, shrink the constant and return true. +static bool ShrinkDemandedConstant(Instruction *I, unsigned OpNo, + const APInt &Demanded) { + assert(I && "No instruction?"); + assert(OpNo < I->getNumOperands() && "Operand index too large"); + + // The operand must be a constant integer or splat integer. + Value *Op = I->getOperand(OpNo); + const APInt *C; + if (!match(Op, m_APInt(C))) + return false; + + // If there are no bits set that aren't demanded, nothing to do. + if (C->isSubsetOf(Demanded)) + return false; + + // This instruction is producing bits that are not demanded. Shrink the RHS. + I->setOperand(OpNo, ConstantInt::get(Op->getType(), *C & Demanded)); + + return true; +} + + + +/// Inst is an integer instruction that SimplifyDemandedBits knows about. See if +/// the instruction has any properties that allow us to simplify its operands. 
+bool InstCombinerImpl::SimplifyDemandedInstructionBits(Instruction &Inst) { + unsigned BitWidth = Inst.getType()->getScalarSizeInBits(); + KnownBits Known(BitWidth); + APInt DemandedMask(APInt::getAllOnes(BitWidth)); + + Value *V = SimplifyDemandedUseBits(&Inst, DemandedMask, Known, + 0, &Inst); + if (!V) return false; + if (V == &Inst) return true; + replaceInstUsesWith(Inst, V); + return true; +} + +/// This form of SimplifyDemandedBits simplifies the specified instruction +/// operand if possible, updating it in place. It returns true if it made any +/// change and false otherwise. +bool InstCombinerImpl::SimplifyDemandedBits(Instruction *I, unsigned OpNo, + const APInt &DemandedMask, + KnownBits &Known, unsigned Depth) { + Use &U = I->getOperandUse(OpNo); + Value *NewVal = SimplifyDemandedUseBits(U.get(), DemandedMask, Known, + Depth, I); + if (!NewVal) return false; + if (Instruction* OpInst = dyn_cast<Instruction>(U)) + salvageDebugInfo(*OpInst); + + replaceUse(U, NewVal); + return true; +} + +/// This function attempts to replace V with a simpler value based on the +/// demanded bits. When this function is called, it is known that only the bits +/// set in DemandedMask of the result of V are ever used downstream. +/// Consequently, depending on the mask and V, it may be possible to replace V +/// with a constant or one of its operands. In such cases, this function does +/// the replacement and returns true. In all other cases, it returns false after +/// analyzing the expression and setting KnownOne and known to be one in the +/// expression. Known.Zero contains all the bits that are known to be zero in +/// the expression. These are provided to potentially allow the caller (which +/// might recursively be SimplifyDemandedBits itself) to simplify the +/// expression. +/// Known.One and Known.Zero always follow the invariant that: +/// Known.One & Known.Zero == 0. +/// That is, a bit can't be both 1 and 0. 
Note that the bits in Known.One and +/// Known.Zero may only be accurate for those bits set in DemandedMask. Note +/// also that the bitwidth of V, DemandedMask, Known.Zero and Known.One must all +/// be the same. +/// +/// This returns null if it did not change anything and it permits no +/// simplification. This returns V itself if it did some simplification of V's +/// operands based on the information about what bits are demanded. This returns +/// some other non-null value if it found out that V is equal to another value +/// in the context where the specified bits are demanded, but not for all users. +Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, + KnownBits &Known, + unsigned Depth, + Instruction *CxtI) { + assert(V != nullptr && "Null pointer of Value???"); + assert(Depth <= MaxAnalysisRecursionDepth && "Limit Search Depth"); + uint32_t BitWidth = DemandedMask.getBitWidth(); + Type *VTy = V->getType(); + assert( + (!VTy->isIntOrIntVectorTy() || VTy->getScalarSizeInBits() == BitWidth) && + Known.getBitWidth() == BitWidth && + "Value *V, DemandedMask and Known must have same BitWidth"); + + if (isa<Constant>(V)) { + computeKnownBits(V, Known, Depth, CxtI); + return nullptr; + } + + Known.resetAll(); + if (DemandedMask.isZero()) // Not demanding any bits from V. + return UndefValue::get(VTy); + + if (Depth == MaxAnalysisRecursionDepth) + return nullptr; + + if (isa<ScalableVectorType>(VTy)) + return nullptr; + + Instruction *I = dyn_cast<Instruction>(V); + if (!I) { + computeKnownBits(V, Known, Depth, CxtI); + return nullptr; // Only analyze instructions. + } + + // If there are multiple uses of this value and we aren't at the root, then + // we can't do any simplifications of the operands, because DemandedMask + // only reflects the bits demanded by *one* of the users. 
+ if (Depth != 0 && !I->hasOneUse()) + return SimplifyMultipleUseDemandedBits(I, DemandedMask, Known, Depth, CxtI); + + KnownBits LHSKnown(BitWidth), RHSKnown(BitWidth); + + // If this is the root being simplified, allow it to have multiple uses, + // just set the DemandedMask to all bits so that we can try to simplify the + // operands. This allows visitTruncInst (for example) to simplify the + // operand of a trunc without duplicating all the logic below. + if (Depth == 0 && !V->hasOneUse()) + DemandedMask.setAllBits(); + + // If the high-bits of an ADD/SUB/MUL are not demanded, then we do not care + // about the high bits of the operands. + auto simplifyOperandsBasedOnUnusedHighBits = [&](APInt &DemandedFromOps) { + unsigned NLZ = DemandedMask.countLeadingZeros(); + // Right fill the mask of bits for the operands to demand the most + // significant bit and all those below it. + DemandedFromOps = APInt::getLowBitsSet(BitWidth, BitWidth - NLZ); + if (ShrinkDemandedConstant(I, 0, DemandedFromOps) || + SimplifyDemandedBits(I, 0, DemandedFromOps, LHSKnown, Depth + 1) || + ShrinkDemandedConstant(I, 1, DemandedFromOps) || + SimplifyDemandedBits(I, 1, DemandedFromOps, RHSKnown, Depth + 1)) { + if (NLZ > 0) { + // Disable the nsw and nuw flags here: We can no longer guarantee that + // we won't wrap after simplification. Removing the nsw/nuw flags is + // legal here because the top bit is not demanded. + I->setHasNoSignedWrap(false); + I->setHasNoUnsignedWrap(false); + } + return true; + } + return false; + }; + + switch (I->getOpcode()) { + default: + computeKnownBits(I, Known, Depth, CxtI); + break; + case Instruction::And: { + // If either the LHS or the RHS are Zero, the result is zero. 
+ if (SimplifyDemandedBits(I, 1, DemandedMask, RHSKnown, Depth + 1) || + SimplifyDemandedBits(I, 0, DemandedMask & ~RHSKnown.Zero, LHSKnown, + Depth + 1)) + return I; + assert(!RHSKnown.hasConflict() && "Bits known to be one AND zero?"); + assert(!LHSKnown.hasConflict() && "Bits known to be one AND zero?"); + + Known = LHSKnown & RHSKnown; + + // If the client is only demanding bits that we know, return the known + // constant. + if (DemandedMask.isSubsetOf(Known.Zero | Known.One)) + return Constant::getIntegerValue(VTy, Known.One); + + // If all of the demanded bits are known 1 on one side, return the other. + // These bits cannot contribute to the result of the 'and'. + if (DemandedMask.isSubsetOf(LHSKnown.Zero | RHSKnown.One)) + return I->getOperand(0); + if (DemandedMask.isSubsetOf(RHSKnown.Zero | LHSKnown.One)) + return I->getOperand(1); + + // If the RHS is a constant, see if we can simplify it. + if (ShrinkDemandedConstant(I, 1, DemandedMask & ~LHSKnown.Zero)) + return I; + + break; + } + case Instruction::Or: { + // If either the LHS or the RHS are One, the result is One. + if (SimplifyDemandedBits(I, 1, DemandedMask, RHSKnown, Depth + 1) || + SimplifyDemandedBits(I, 0, DemandedMask & ~RHSKnown.One, LHSKnown, + Depth + 1)) + return I; + assert(!RHSKnown.hasConflict() && "Bits known to be one AND zero?"); + assert(!LHSKnown.hasConflict() && "Bits known to be one AND zero?"); + + Known = LHSKnown | RHSKnown; + + // If the client is only demanding bits that we know, return the known + // constant. + if (DemandedMask.isSubsetOf(Known.Zero | Known.One)) + return Constant::getIntegerValue(VTy, Known.One); + + // If all of the demanded bits are known zero on one side, return the other. + // These bits cannot contribute to the result of the 'or'. 
+ if (DemandedMask.isSubsetOf(LHSKnown.One | RHSKnown.Zero)) + return I->getOperand(0); + if (DemandedMask.isSubsetOf(RHSKnown.One | LHSKnown.Zero)) + return I->getOperand(1); + + // If the RHS is a constant, see if we can simplify it. + if (ShrinkDemandedConstant(I, 1, DemandedMask)) + return I; + + break; + } + case Instruction::Xor: { + if (SimplifyDemandedBits(I, 1, DemandedMask, RHSKnown, Depth + 1) || + SimplifyDemandedBits(I, 0, DemandedMask, LHSKnown, Depth + 1)) + return I; + Value *LHS, *RHS; + if (DemandedMask == 1 && + match(I->getOperand(0), m_Intrinsic<Intrinsic::ctpop>(m_Value(LHS))) && + match(I->getOperand(1), m_Intrinsic<Intrinsic::ctpop>(m_Value(RHS)))) { + // (ctpop(X) ^ ctpop(Y)) & 1 --> ctpop(X^Y) & 1 + IRBuilderBase::InsertPointGuard Guard(Builder); + Builder.SetInsertPoint(I); + auto *Xor = Builder.CreateXor(LHS, RHS); + return Builder.CreateUnaryIntrinsic(Intrinsic::ctpop, Xor); + } + + assert(!RHSKnown.hasConflict() && "Bits known to be one AND zero?"); + assert(!LHSKnown.hasConflict() && "Bits known to be one AND zero?"); + + Known = LHSKnown ^ RHSKnown; + + // If the client is only demanding bits that we know, return the known + // constant. + if (DemandedMask.isSubsetOf(Known.Zero | Known.One)) + return Constant::getIntegerValue(VTy, Known.One); + + // If all of the demanded bits are known zero on one side, return the other. + // These bits cannot contribute to the result of the 'xor'. + if (DemandedMask.isSubsetOf(RHSKnown.Zero)) + return I->getOperand(0); + if (DemandedMask.isSubsetOf(LHSKnown.Zero)) + return I->getOperand(1); + + // If all of the demanded bits are known to be zero on one side or the + // other, turn this into an *inclusive* or. + // e.g. 
(A & C1)^(B & C2) -> (A & C1)|(B & C2) iff C1&C2 == 0 + if (DemandedMask.isSubsetOf(RHSKnown.Zero | LHSKnown.Zero)) { + Instruction *Or = + BinaryOperator::CreateOr(I->getOperand(0), I->getOperand(1), + I->getName()); + return InsertNewInstWith(Or, *I); + } + + // If all of the demanded bits on one side are known, and all of the set + // bits on that side are also known to be set on the other side, turn this + // into an AND, as we know the bits will be cleared. + // e.g. (X | C1) ^ C2 --> (X | C1) & ~C2 iff (C1&C2) == C2 + if (DemandedMask.isSubsetOf(RHSKnown.Zero|RHSKnown.One) && + RHSKnown.One.isSubsetOf(LHSKnown.One)) { + Constant *AndC = Constant::getIntegerValue(VTy, + ~RHSKnown.One & DemandedMask); + Instruction *And = BinaryOperator::CreateAnd(I->getOperand(0), AndC); + return InsertNewInstWith(And, *I); + } + + // If the RHS is a constant, see if we can change it. Don't alter a -1 + // constant because that's a canonical 'not' op, and that is better for + // combining, SCEV, and codegen. + const APInt *C; + if (match(I->getOperand(1), m_APInt(C)) && !C->isAllOnes()) { + if ((*C | ~DemandedMask).isAllOnes()) { + // Force bits to 1 to create a 'not' op. + I->setOperand(1, ConstantInt::getAllOnesValue(VTy)); + return I; + } + // If we can't turn this into a 'not', try to shrink the constant. + if (ShrinkDemandedConstant(I, 1, DemandedMask)) + return I; + } + + // If our LHS is an 'and' and if it has one use, and if any of the bits we + // are flipping are known to be set, then the xor is just resetting those + // bits to zero. We can just knock out bits from the 'and' and the 'xor', + // simplifying both of them. 
+ if (Instruction *LHSInst = dyn_cast<Instruction>(I->getOperand(0))) { + ConstantInt *AndRHS, *XorRHS; + if (LHSInst->getOpcode() == Instruction::And && LHSInst->hasOneUse() && + match(I->getOperand(1), m_ConstantInt(XorRHS)) && + match(LHSInst->getOperand(1), m_ConstantInt(AndRHS)) && + (LHSKnown.One & RHSKnown.One & DemandedMask) != 0) { + APInt NewMask = ~(LHSKnown.One & RHSKnown.One & DemandedMask); + + Constant *AndC = ConstantInt::get(VTy, NewMask & AndRHS->getValue()); + Instruction *NewAnd = BinaryOperator::CreateAnd(I->getOperand(0), AndC); + InsertNewInstWith(NewAnd, *I); + + Constant *XorC = ConstantInt::get(VTy, NewMask & XorRHS->getValue()); + Instruction *NewXor = BinaryOperator::CreateXor(NewAnd, XorC); + return InsertNewInstWith(NewXor, *I); + } + } + break; + } + case Instruction::Select: { + if (SimplifyDemandedBits(I, 2, DemandedMask, RHSKnown, Depth + 1) || + SimplifyDemandedBits(I, 1, DemandedMask, LHSKnown, Depth + 1)) + return I; + assert(!RHSKnown.hasConflict() && "Bits known to be one AND zero?"); + assert(!LHSKnown.hasConflict() && "Bits known to be one AND zero?"); + + // If the operands are constants, see if we can simplify them. + // This is similar to ShrinkDemandedConstant, but for a select we want to + // try to keep the selected constants the same as icmp value constants, if + // we can. This helps not break apart (or helps put back together) + // canonical patterns like min and max. + auto CanonicalizeSelectConstant = [](Instruction *I, unsigned OpNo, + const APInt &DemandedMask) { + const APInt *SelC; + if (!match(I->getOperand(OpNo), m_APInt(SelC))) + return false; + + // Get the constant out of the ICmp, if there is one. + // Only try this when exactly 1 operand is a constant (if both operands + // are constant, the icmp should eventually simplify). Otherwise, we may + // invert the transform that reduces set bits and infinite-loop. 
+ Value *X; + const APInt *CmpC; + ICmpInst::Predicate Pred; + if (!match(I->getOperand(0), m_ICmp(Pred, m_Value(X), m_APInt(CmpC))) || + isa<Constant>(X) || CmpC->getBitWidth() != SelC->getBitWidth()) + return ShrinkDemandedConstant(I, OpNo, DemandedMask); + + // If the constant is already the same as the ICmp, leave it as-is. + if (*CmpC == *SelC) + return false; + // If the constants are not already the same, but can be with the demand + // mask, use the constant value from the ICmp. + if ((*CmpC & DemandedMask) == (*SelC & DemandedMask)) { + I->setOperand(OpNo, ConstantInt::get(I->getType(), *CmpC)); + return true; + } + return ShrinkDemandedConstant(I, OpNo, DemandedMask); + }; + if (CanonicalizeSelectConstant(I, 1, DemandedMask) || + CanonicalizeSelectConstant(I, 2, DemandedMask)) + return I; + + // Only known if known in both the LHS and RHS. + Known = KnownBits::commonBits(LHSKnown, RHSKnown); + break; + } + case Instruction::Trunc: { + // If we do not demand the high bits of a right-shifted and truncated value, + // then we may be able to truncate it before the shift. + Value *X; + const APInt *C; + if (match(I->getOperand(0), m_OneUse(m_LShr(m_Value(X), m_APInt(C))))) { + // The shift amount must be valid (not poison) in the narrow type, and + // it must not be greater than the high bits demanded of the result. 
+ if (C->ult(VTy->getScalarSizeInBits()) && + C->ule(DemandedMask.countLeadingZeros())) { + // trunc (lshr X, C) --> lshr (trunc X), C + IRBuilderBase::InsertPointGuard Guard(Builder); + Builder.SetInsertPoint(I); + Value *Trunc = Builder.CreateTrunc(X, VTy); + return Builder.CreateLShr(Trunc, C->getZExtValue()); + } + } + } + LLVM_FALLTHROUGH; + case Instruction::ZExt: { + unsigned SrcBitWidth = I->getOperand(0)->getType()->getScalarSizeInBits(); + + APInt InputDemandedMask = DemandedMask.zextOrTrunc(SrcBitWidth); + KnownBits InputKnown(SrcBitWidth); + if (SimplifyDemandedBits(I, 0, InputDemandedMask, InputKnown, Depth + 1)) + return I; + assert(InputKnown.getBitWidth() == SrcBitWidth && "Src width changed?"); + Known = InputKnown.zextOrTrunc(BitWidth); + assert(!Known.hasConflict() && "Bits known to be one AND zero?"); + break; + } + case Instruction::BitCast: + if (!I->getOperand(0)->getType()->isIntOrIntVectorTy()) + return nullptr; // vector->int or fp->int? + + if (auto *DstVTy = dyn_cast<VectorType>(VTy)) { + if (auto *SrcVTy = dyn_cast<VectorType>(I->getOperand(0)->getType())) { + if (cast<FixedVectorType>(DstVTy)->getNumElements() != + cast<FixedVectorType>(SrcVTy)->getNumElements()) + // Don't touch a bitcast between vectors of different element counts. + return nullptr; + } else + // Don't touch a scalar-to-vector bitcast. + return nullptr; + } else if (I->getOperand(0)->getType()->isVectorTy()) + // Don't touch a vector-to-scalar bitcast. + return nullptr; + + if (SimplifyDemandedBits(I, 0, DemandedMask, Known, Depth + 1)) + return I; + assert(!Known.hasConflict() && "Bits known to be one AND zero?"); + break; + case Instruction::SExt: { + // Compute the bits in the result that are not present in the input. + unsigned SrcBitWidth = I->getOperand(0)->getType()->getScalarSizeInBits(); + + APInt InputDemandedBits = DemandedMask.trunc(SrcBitWidth); + + // If any of the sign extended bits are demanded, we know that the sign + // bit is demanded. 
+ if (DemandedMask.getActiveBits() > SrcBitWidth) + InputDemandedBits.setBit(SrcBitWidth-1); + + KnownBits InputKnown(SrcBitWidth); + if (SimplifyDemandedBits(I, 0, InputDemandedBits, InputKnown, Depth + 1)) + return I; + + // If the input sign bit is known zero, or if the NewBits are not demanded + // convert this into a zero extension. + if (InputKnown.isNonNegative() || + DemandedMask.getActiveBits() <= SrcBitWidth) { + // Convert to ZExt cast. + CastInst *NewCast = new ZExtInst(I->getOperand(0), VTy, I->getName()); + return InsertNewInstWith(NewCast, *I); + } + + // If the sign bit of the input is known set or clear, then we know the + // top bits of the result. + Known = InputKnown.sext(BitWidth); + assert(!Known.hasConflict() && "Bits known to be one AND zero?"); + break; + } + case Instruction::Add: + if ((DemandedMask & 1) == 0) { + // If we do not need the low bit, try to convert bool math to logic: + // add iN (zext i1 X), (sext i1 Y) --> sext (~X & Y) to iN + Value *X, *Y; + if (match(I, m_c_Add(m_OneUse(m_ZExt(m_Value(X))), + m_OneUse(m_SExt(m_Value(Y))))) && + X->getType()->isIntOrIntVectorTy(1) && X->getType() == Y->getType()) { + // Truth table for inputs and output signbits: + // X:0 | X:1 + // ---------- + // Y:0 | 0 | 0 | + // Y:1 | -1 | 0 | + // ---------- + IRBuilderBase::InsertPointGuard Guard(Builder); + Builder.SetInsertPoint(I); + Value *AndNot = Builder.CreateAnd(Builder.CreateNot(X), Y); + return Builder.CreateSExt(AndNot, VTy); + } + + // add iN (sext i1 X), (sext i1 Y) --> sext (X | Y) to iN + // TODO: Relax the one-use checks because we are removing an instruction? 
+ if (match(I, m_Add(m_OneUse(m_SExt(m_Value(X))), + m_OneUse(m_SExt(m_Value(Y))))) && + X->getType()->isIntOrIntVectorTy(1) && X->getType() == Y->getType()) { + // Truth table for inputs and output signbits: + // X:0 | X:1 + // ----------- + // Y:0 | -1 | -1 | + // Y:1 | -1 | 0 | + // ----------- + IRBuilderBase::InsertPointGuard Guard(Builder); + Builder.SetInsertPoint(I); + Value *Or = Builder.CreateOr(X, Y); + return Builder.CreateSExt(Or, VTy); + } + } + LLVM_FALLTHROUGH; + case Instruction::Sub: { + APInt DemandedFromOps; + if (simplifyOperandsBasedOnUnusedHighBits(DemandedFromOps)) + return I; + + // If we are known to be adding/subtracting zeros to every bit below + // the highest demanded bit, we just return the other side. + if (DemandedFromOps.isSubsetOf(RHSKnown.Zero)) + return I->getOperand(0); + // We can't do this with the LHS for subtraction, unless we are only + // demanding the LSB. + if ((I->getOpcode() == Instruction::Add || DemandedFromOps.isOne()) && + DemandedFromOps.isSubsetOf(LHSKnown.Zero)) + return I->getOperand(1); + + // Otherwise just compute the known bits of the result. + bool NSW = cast<OverflowingBinaryOperator>(I)->hasNoSignedWrap(); + Known = KnownBits::computeForAddSub(I->getOpcode() == Instruction::Add, + NSW, LHSKnown, RHSKnown); + break; + } + case Instruction::Mul: { + APInt DemandedFromOps; + if (simplifyOperandsBasedOnUnusedHighBits(DemandedFromOps)) + return I; + + if (DemandedMask.isPowerOf2()) { + // The LSB of X*Y is set only if (X & 1) == 1 and (Y & 1) == 1. + // If we demand exactly one bit N and we have "X * (C' << N)" where C' is + // odd (has LSB set), then the left-shifted low bit of X is the answer. 
+ unsigned CTZ = DemandedMask.countTrailingZeros(); + const APInt *C; + if (match(I->getOperand(1), m_APInt(C)) && + C->countTrailingZeros() == CTZ) { + Constant *ShiftC = ConstantInt::get(VTy, CTZ); + Instruction *Shl = BinaryOperator::CreateShl(I->getOperand(0), ShiftC); + return InsertNewInstWith(Shl, *I); + } + } + // For a squared value "X * X", the bottom 2 bits are 0 and X[0] because: + // X * X is odd iff X is odd. + // 'Quadratic Reciprocity': X * X -> 0 for bit[1] + if (I->getOperand(0) == I->getOperand(1) && DemandedMask.ult(4)) { + Constant *One = ConstantInt::get(VTy, 1); + Instruction *And1 = BinaryOperator::CreateAnd(I->getOperand(0), One); + return InsertNewInstWith(And1, *I); + } + + computeKnownBits(I, Known, Depth, CxtI); + break; + } + case Instruction::Shl: { + const APInt *SA; + if (match(I->getOperand(1), m_APInt(SA))) { + const APInt *ShrAmt; + if (match(I->getOperand(0), m_Shr(m_Value(), m_APInt(ShrAmt)))) + if (Instruction *Shr = dyn_cast<Instruction>(I->getOperand(0))) + if (Value *R = simplifyShrShlDemandedBits(Shr, *ShrAmt, I, *SA, + DemandedMask, Known)) + return R; + + // TODO: If we only want bits that already match the signbit then we don't + // need to shift. 
+ + // If we can pre-shift a right-shifted constant to the left without + // losing any high bits amd we don't demand the low bits, then eliminate + // the left-shift: + // (C >> X) << LeftShiftAmtC --> (C << RightShiftAmtC) >> X + uint64_t ShiftAmt = SA->getLimitedValue(BitWidth-1); + Value *X; + Constant *C; + if (DemandedMask.countTrailingZeros() >= ShiftAmt && + match(I->getOperand(0), m_LShr(m_ImmConstant(C), m_Value(X)))) { + Constant *LeftShiftAmtC = ConstantInt::get(VTy, ShiftAmt); + Constant *NewC = ConstantExpr::getShl(C, LeftShiftAmtC); + if (ConstantExpr::getLShr(NewC, LeftShiftAmtC) == C) { + Instruction *Lshr = BinaryOperator::CreateLShr(NewC, X); + return InsertNewInstWith(Lshr, *I); + } + } + + APInt DemandedMaskIn(DemandedMask.lshr(ShiftAmt)); + + // If the shift is NUW/NSW, then it does demand the high bits. + ShlOperator *IOp = cast<ShlOperator>(I); + if (IOp->hasNoSignedWrap()) + DemandedMaskIn.setHighBits(ShiftAmt+1); + else if (IOp->hasNoUnsignedWrap()) + DemandedMaskIn.setHighBits(ShiftAmt); + + if (SimplifyDemandedBits(I, 0, DemandedMaskIn, Known, Depth + 1)) + return I; + assert(!Known.hasConflict() && "Bits known to be one AND zero?"); + + bool SignBitZero = Known.Zero.isSignBitSet(); + bool SignBitOne = Known.One.isSignBitSet(); + Known.Zero <<= ShiftAmt; + Known.One <<= ShiftAmt; + // low bits known zero. + if (ShiftAmt) + Known.Zero.setLowBits(ShiftAmt); + + // If this shift has "nsw" keyword, then the result is either a poison + // value or has the same sign bit as the first operand. + if (IOp->hasNoSignedWrap()) { + if (SignBitZero) + Known.Zero.setSignBit(); + else if (SignBitOne) + Known.One.setSignBit(); + if (Known.hasConflict()) + return UndefValue::get(VTy); + } + } else { + // This is a variable shift, so we can't shift the demand mask by a known + // amount. But if we are not demanding high bits, then we are not + // demanding those bits from the pre-shifted operand either. 
+ if (unsigned CTLZ = DemandedMask.countLeadingZeros()) { + APInt DemandedFromOp(APInt::getLowBitsSet(BitWidth, BitWidth - CTLZ)); + if (SimplifyDemandedBits(I, 0, DemandedFromOp, Known, Depth + 1)) { + // We can't guarantee that nsw/nuw hold after simplifying the operand. + I->dropPoisonGeneratingFlags(); + return I; + } + } + computeKnownBits(I, Known, Depth, CxtI); + } + break; + } + case Instruction::LShr: { + const APInt *SA; + if (match(I->getOperand(1), m_APInt(SA))) { + uint64_t ShiftAmt = SA->getLimitedValue(BitWidth-1); + + // If we are just demanding the shifted sign bit and below, then this can + // be treated as an ASHR in disguise. + if (DemandedMask.countLeadingZeros() >= ShiftAmt) { + // If we only want bits that already match the signbit then we don't + // need to shift. + unsigned NumHiDemandedBits = + BitWidth - DemandedMask.countTrailingZeros(); + unsigned SignBits = + ComputeNumSignBits(I->getOperand(0), Depth + 1, CxtI); + if (SignBits >= NumHiDemandedBits) + return I->getOperand(0); + + // If we can pre-shift a left-shifted constant to the right without + // losing any low bits (we already know we don't demand the high bits), + // then eliminate the right-shift: + // (C << X) >> RightShiftAmtC --> (C >> RightShiftAmtC) << X + Value *X; + Constant *C; + if (match(I->getOperand(0), m_Shl(m_ImmConstant(C), m_Value(X)))) { + Constant *RightShiftAmtC = ConstantInt::get(VTy, ShiftAmt); + Constant *NewC = ConstantExpr::getLShr(C, RightShiftAmtC); + if (ConstantExpr::getShl(NewC, RightShiftAmtC) == C) { + Instruction *Shl = BinaryOperator::CreateShl(NewC, X); + return InsertNewInstWith(Shl, *I); + } + } + } + + // Unsigned shift right. + APInt DemandedMaskIn(DemandedMask.shl(ShiftAmt)); + + // If the shift is exact, then it does demand the low bits (and knows that + // they are zero). 
+ if (cast<LShrOperator>(I)->isExact()) + DemandedMaskIn.setLowBits(ShiftAmt); + + if (SimplifyDemandedBits(I, 0, DemandedMaskIn, Known, Depth + 1)) + return I; + assert(!Known.hasConflict() && "Bits known to be one AND zero?"); + Known.Zero.lshrInPlace(ShiftAmt); + Known.One.lshrInPlace(ShiftAmt); + if (ShiftAmt) + Known.Zero.setHighBits(ShiftAmt); // high bits known zero. + } else { + computeKnownBits(I, Known, Depth, CxtI); + } + break; + } + case Instruction::AShr: { + unsigned SignBits = ComputeNumSignBits(I->getOperand(0), Depth + 1, CxtI); + + // If we only want bits that already match the signbit then we don't need + // to shift. + unsigned NumHiDemandedBits = BitWidth - DemandedMask.countTrailingZeros(); + if (SignBits >= NumHiDemandedBits) + return I->getOperand(0); + + // If this is an arithmetic shift right and only the low-bit is set, we can + // always convert this into a logical shr, even if the shift amount is + // variable. The low bit of the shift cannot be an input sign bit unless + // the shift amount is >= the size of the datatype, which is undefined. + if (DemandedMask.isOne()) { + // Perform the logical shift right. + Instruction *NewVal = BinaryOperator::CreateLShr( + I->getOperand(0), I->getOperand(1), I->getName()); + return InsertNewInstWith(NewVal, *I); + } + + const APInt *SA; + if (match(I->getOperand(1), m_APInt(SA))) { + uint32_t ShiftAmt = SA->getLimitedValue(BitWidth-1); + + // Signed shift right. + APInt DemandedMaskIn(DemandedMask.shl(ShiftAmt)); + // If any of the high bits are demanded, we should set the sign bit as + // demanded. + if (DemandedMask.countLeadingZeros() <= ShiftAmt) + DemandedMaskIn.setSignBit(); + + // If the shift is exact, then it does demand the low bits (and knows that + // they are zero). 
+ if (cast<AShrOperator>(I)->isExact()) + DemandedMaskIn.setLowBits(ShiftAmt); + + if (SimplifyDemandedBits(I, 0, DemandedMaskIn, Known, Depth + 1)) + return I; + + assert(!Known.hasConflict() && "Bits known to be one AND zero?"); + // Compute the new bits that are at the top now plus sign bits. + APInt HighBits(APInt::getHighBitsSet( + BitWidth, std::min(SignBits + ShiftAmt - 1, BitWidth))); + Known.Zero.lshrInPlace(ShiftAmt); + Known.One.lshrInPlace(ShiftAmt); + + // If the input sign bit is known to be zero, or if none of the top bits + // are demanded, turn this into an unsigned shift right. + assert(BitWidth > ShiftAmt && "Shift amount not saturated?"); + if (Known.Zero[BitWidth-ShiftAmt-1] || + !DemandedMask.intersects(HighBits)) { + BinaryOperator *LShr = BinaryOperator::CreateLShr(I->getOperand(0), + I->getOperand(1)); + LShr->setIsExact(cast<BinaryOperator>(I)->isExact()); + return InsertNewInstWith(LShr, *I); + } else if (Known.One[BitWidth-ShiftAmt-1]) { // New bits are known one. + Known.One |= HighBits; + } + } else { + computeKnownBits(I, Known, Depth, CxtI); + } + break; + } + case Instruction::UDiv: { + // UDiv doesn't demand low bits that are zero in the divisor. + const APInt *SA; + if (match(I->getOperand(1), m_APInt(SA))) { + // If the shift is exact, then it does demand the low bits. + if (cast<UDivOperator>(I)->isExact()) + break; + + // FIXME: Take the demanded mask of the result into account. + unsigned RHSTrailingZeros = SA->countTrailingZeros(); + APInt DemandedMaskIn = + APInt::getHighBitsSet(BitWidth, BitWidth - RHSTrailingZeros); + if (SimplifyDemandedBits(I, 0, DemandedMaskIn, LHSKnown, Depth + 1)) + return I; + + // Propagate zero bits from the input. 
+ Known.Zero.setHighBits(std::min( + BitWidth, LHSKnown.Zero.countLeadingOnes() + RHSTrailingZeros)); + } else { + computeKnownBits(I, Known, Depth, CxtI); + } + break; + } + case Instruction::SRem: { + const APInt *Rem; + if (match(I->getOperand(1), m_APInt(Rem))) { + // X % -1 demands all the bits because we don't want to introduce + // INT_MIN % -1 (== undef) by accident. + if (Rem->isAllOnes()) + break; + APInt RA = Rem->abs(); + if (RA.isPowerOf2()) { + if (DemandedMask.ult(RA)) // srem won't affect demanded bits + return I->getOperand(0); + + APInt LowBits = RA - 1; + APInt Mask2 = LowBits | APInt::getSignMask(BitWidth); + if (SimplifyDemandedBits(I, 0, Mask2, LHSKnown, Depth + 1)) + return I; + + // The low bits of LHS are unchanged by the srem. + Known.Zero = LHSKnown.Zero & LowBits; + Known.One = LHSKnown.One & LowBits; + + // If LHS is non-negative or has all low bits zero, then the upper bits + // are all zero. + if (LHSKnown.isNonNegative() || LowBits.isSubsetOf(LHSKnown.Zero)) + Known.Zero |= ~LowBits; + + // If LHS is negative and not all low bits are zero, then the upper bits + // are all one. + if (LHSKnown.isNegative() && LowBits.intersects(LHSKnown.One)) + Known.One |= ~LowBits; + + assert(!Known.hasConflict() && "Bits known to be one AND zero?"); + break; + } + } + + // The sign bit is the LHS's sign bit, except when the result of the + // remainder is zero. + if (DemandedMask.isSignBitSet()) { + computeKnownBits(I->getOperand(0), LHSKnown, Depth + 1, CxtI); + // If it's known zero, our sign bit is also zero. 
+ if (LHSKnown.isNonNegative()) + Known.makeNonNegative(); + } + break; + } + case Instruction::URem: { + KnownBits Known2(BitWidth); + APInt AllOnes = APInt::getAllOnes(BitWidth); + if (SimplifyDemandedBits(I, 0, AllOnes, Known2, Depth + 1) || + SimplifyDemandedBits(I, 1, AllOnes, Known2, Depth + 1)) + return I; + + unsigned Leaders = Known2.countMinLeadingZeros(); + Known.Zero = APInt::getHighBitsSet(BitWidth, Leaders) & DemandedMask; + break; + } + case Instruction::Call: { + bool KnownBitsComputed = false; + if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(I)) { + switch (II->getIntrinsicID()) { + case Intrinsic::abs: { + if (DemandedMask == 1) + return II->getArgOperand(0); + break; + } + case Intrinsic::ctpop: { + // Checking if the number of clear bits is odd (parity)? If the type has + // an even number of bits, that's the same as checking if the number of + // set bits is odd, so we can eliminate the 'not' op. + Value *X; + if (DemandedMask == 1 && VTy->getScalarSizeInBits() % 2 == 0 && + match(II->getArgOperand(0), m_Not(m_Value(X)))) { + Function *Ctpop = Intrinsic::getDeclaration( + II->getModule(), Intrinsic::ctpop, VTy); + return InsertNewInstWith(CallInst::Create(Ctpop, {X}), *I); + } + break; + } + case Intrinsic::bswap: { + // If the only bits demanded come from one byte of the bswap result, + // just shift the input byte into position to eliminate the bswap. + unsigned NLZ = DemandedMask.countLeadingZeros(); + unsigned NTZ = DemandedMask.countTrailingZeros(); + + // Round NTZ down to the next byte. If we have 11 trailing zeros, then + // we need all the bits down to bit 8. Likewise, round NLZ. If we + // have 14 leading zeros, round to 8. + NLZ = alignDown(NLZ, 8); + NTZ = alignDown(NTZ, 8); + // If we need exactly one byte, we can do this transformation. + if (BitWidth - NLZ - NTZ == 8) { + // Replace this with either a left or right shift to get the byte into + // the right place. 
+ Instruction *NewVal; + if (NLZ > NTZ) + NewVal = BinaryOperator::CreateLShr( + II->getArgOperand(0), ConstantInt::get(VTy, NLZ - NTZ)); + else + NewVal = BinaryOperator::CreateShl( + II->getArgOperand(0), ConstantInt::get(VTy, NTZ - NLZ)); + NewVal->takeName(I); + return InsertNewInstWith(NewVal, *I); + } + break; + } + case Intrinsic::fshr: + case Intrinsic::fshl: { + const APInt *SA; + if (!match(I->getOperand(2), m_APInt(SA))) + break; + + // Normalize to funnel shift left. APInt shifts of BitWidth are well- + // defined, so no need to special-case zero shifts here. + uint64_t ShiftAmt = SA->urem(BitWidth); + if (II->getIntrinsicID() == Intrinsic::fshr) + ShiftAmt = BitWidth - ShiftAmt; + + APInt DemandedMaskLHS(DemandedMask.lshr(ShiftAmt)); + APInt DemandedMaskRHS(DemandedMask.shl(BitWidth - ShiftAmt)); + if (SimplifyDemandedBits(I, 0, DemandedMaskLHS, LHSKnown, Depth + 1) || + SimplifyDemandedBits(I, 1, DemandedMaskRHS, RHSKnown, Depth + 1)) + return I; + + Known.Zero = LHSKnown.Zero.shl(ShiftAmt) | + RHSKnown.Zero.lshr(BitWidth - ShiftAmt); + Known.One = LHSKnown.One.shl(ShiftAmt) | + RHSKnown.One.lshr(BitWidth - ShiftAmt); + KnownBitsComputed = true; + break; + } + case Intrinsic::umax: { + // UMax(A, C) == A if ... + // The lowest non-zero bit of DemandMask is higher than the highest + // non-zero bit of C. + const APInt *C; + unsigned CTZ = DemandedMask.countTrailingZeros(); + if (match(II->getArgOperand(1), m_APInt(C)) && + CTZ >= C->getActiveBits()) + return II->getArgOperand(0); + break; + } + case Intrinsic::umin: { + // UMin(A, C) == A if ... + // The lowest non-zero bit of DemandMask is higher than the highest + // non-one bit of C. + // This comes from using DeMorgans on the above umax example. 
+ const APInt *C; + unsigned CTZ = DemandedMask.countTrailingZeros(); + if (match(II->getArgOperand(1), m_APInt(C)) && + CTZ >= C->getBitWidth() - C->countLeadingOnes()) + return II->getArgOperand(0); + break; + } + default: { + // Handle target specific intrinsics + Optional<Value *> V = targetSimplifyDemandedUseBitsIntrinsic( + *II, DemandedMask, Known, KnownBitsComputed); + if (V) + return V.getValue(); + break; + } + } + } + + if (!KnownBitsComputed) + computeKnownBits(V, Known, Depth, CxtI); + break; + } + } + + // If the client is only demanding bits that we know, return the known + // constant. + if (DemandedMask.isSubsetOf(Known.Zero|Known.One)) + return Constant::getIntegerValue(VTy, Known.One); + return nullptr; +} + +/// Helper routine of SimplifyDemandedUseBits. It computes Known +/// bits. It also tries to handle simplifications that can be done based on +/// DemandedMask, but without modifying the Instruction. +Value *InstCombinerImpl::SimplifyMultipleUseDemandedBits( + Instruction *I, const APInt &DemandedMask, KnownBits &Known, unsigned Depth, + Instruction *CxtI) { + unsigned BitWidth = DemandedMask.getBitWidth(); + Type *ITy = I->getType(); + + KnownBits LHSKnown(BitWidth); + KnownBits RHSKnown(BitWidth); + + // Despite the fact that we can't simplify this instruction in all User's + // context, we can at least compute the known bits, and we can + // do simplifications that apply to *just* the one user if we know that + // this instruction has a simpler value in that context. + switch (I->getOpcode()) { + case Instruction::And: { + // If either the LHS or the RHS are Zero, the result is zero. + computeKnownBits(I->getOperand(1), RHSKnown, Depth + 1, CxtI); + computeKnownBits(I->getOperand(0), LHSKnown, Depth + 1, + CxtI); + + Known = LHSKnown & RHSKnown; + + // If the client is only demanding bits that we know, return the known + // constant. 
+ if (DemandedMask.isSubsetOf(Known.Zero | Known.One)) + return Constant::getIntegerValue(ITy, Known.One); + + // If all of the demanded bits are known 1 on one side, return the other. + // These bits cannot contribute to the result of the 'and' in this + // context. + if (DemandedMask.isSubsetOf(LHSKnown.Zero | RHSKnown.One)) + return I->getOperand(0); + if (DemandedMask.isSubsetOf(RHSKnown.Zero | LHSKnown.One)) + return I->getOperand(1); + + break; + } + case Instruction::Or: { + // We can simplify (X|Y) -> X or Y in the user's context if we know that + // only bits from X or Y are demanded. + + // If either the LHS or the RHS are One, the result is One. + computeKnownBits(I->getOperand(1), RHSKnown, Depth + 1, CxtI); + computeKnownBits(I->getOperand(0), LHSKnown, Depth + 1, + CxtI); + + Known = LHSKnown | RHSKnown; + + // If the client is only demanding bits that we know, return the known + // constant. + if (DemandedMask.isSubsetOf(Known.Zero | Known.One)) + return Constant::getIntegerValue(ITy, Known.One); + + // If all of the demanded bits are known zero on one side, return the + // other. These bits cannot contribute to the result of the 'or' in this + // context. + if (DemandedMask.isSubsetOf(LHSKnown.One | RHSKnown.Zero)) + return I->getOperand(0); + if (DemandedMask.isSubsetOf(RHSKnown.One | LHSKnown.Zero)) + return I->getOperand(1); + + break; + } + case Instruction::Xor: { + // We can simplify (X^Y) -> X or Y in the user's context if we know that + // only bits from X or Y are demanded. + + computeKnownBits(I->getOperand(1), RHSKnown, Depth + 1, CxtI); + computeKnownBits(I->getOperand(0), LHSKnown, Depth + 1, + CxtI); + + Known = LHSKnown ^ RHSKnown; + + // If the client is only demanding bits that we know, return the known + // constant. + if (DemandedMask.isSubsetOf(Known.Zero | Known.One)) + return Constant::getIntegerValue(ITy, Known.One); + + // If all of the demanded bits are known zero on one side, return the + // other. 
+ if (DemandedMask.isSubsetOf(RHSKnown.Zero)) + return I->getOperand(0); + if (DemandedMask.isSubsetOf(LHSKnown.Zero)) + return I->getOperand(1); + + break; + } + case Instruction::AShr: { + // Compute the Known bits to simplify things downstream. + computeKnownBits(I, Known, Depth, CxtI); + + // If this user is only demanding bits that we know, return the known + // constant. + if (DemandedMask.isSubsetOf(Known.Zero | Known.One)) + return Constant::getIntegerValue(ITy, Known.One); + + // If the right shift operand 0 is a result of a left shift by the same + // amount, this is probably a zero/sign extension, which may be unnecessary, + // if we do not demand any of the new sign bits. So, return the original + // operand instead. + const APInt *ShiftRC; + const APInt *ShiftLC; + Value *X; + unsigned BitWidth = DemandedMask.getBitWidth(); + if (match(I, + m_AShr(m_Shl(m_Value(X), m_APInt(ShiftLC)), m_APInt(ShiftRC))) && + ShiftLC == ShiftRC && ShiftLC->ult(BitWidth) && + DemandedMask.isSubsetOf(APInt::getLowBitsSet( + BitWidth, BitWidth - ShiftRC->getZExtValue()))) { + return X; + } + + break; + } + default: + // Compute the Known bits to simplify things downstream. + computeKnownBits(I, Known, Depth, CxtI); + + // If this user is only demanding bits that we know, return the known + // constant. + if (DemandedMask.isSubsetOf(Known.Zero|Known.One)) + return Constant::getIntegerValue(ITy, Known.One); + + break; + } + + return nullptr; +} + +/// Helper routine of SimplifyDemandedUseBits. It tries to simplify +/// "E1 = (X lsr C1) << C2", where the C1 and C2 are constant, into +/// "E2 = X << (C2 - C1)" or "E2 = X >> (C1 - C2)", depending on the sign +/// of "C2-C1". +/// +/// Suppose E1 and E2 are generally different in bits S={bm, bm+1, +/// ..., bn}, without considering the specific value X is holding. +/// This transformation is legal iff one of following conditions is hold: +/// 1) All the bit in S are 0, in this case E1 == E2. 
+/// 2) We don't care those bits in S, per the input DemandedMask. +/// 3) Combination of 1) and 2). Some bits in S are 0, and we don't care the +/// rest bits. +/// +/// Currently we only test condition 2). +/// +/// As with SimplifyDemandedUseBits, it returns NULL if the simplification was +/// not successful. +Value *InstCombinerImpl::simplifyShrShlDemandedBits( + Instruction *Shr, const APInt &ShrOp1, Instruction *Shl, + const APInt &ShlOp1, const APInt &DemandedMask, KnownBits &Known) { + if (!ShlOp1 || !ShrOp1) + return nullptr; // No-op. + + Value *VarX = Shr->getOperand(0); + Type *Ty = VarX->getType(); + unsigned BitWidth = Ty->getScalarSizeInBits(); + if (ShlOp1.uge(BitWidth) || ShrOp1.uge(BitWidth)) + return nullptr; // Undef. + + unsigned ShlAmt = ShlOp1.getZExtValue(); + unsigned ShrAmt = ShrOp1.getZExtValue(); + + Known.One.clearAllBits(); + Known.Zero.setLowBits(ShlAmt - 1); + Known.Zero &= DemandedMask; + + APInt BitMask1(APInt::getAllOnes(BitWidth)); + APInt BitMask2(APInt::getAllOnes(BitWidth)); + + bool isLshr = (Shr->getOpcode() == Instruction::LShr); + BitMask1 = isLshr ? (BitMask1.lshr(ShrAmt) << ShlAmt) : + (BitMask1.ashr(ShrAmt) << ShlAmt); + + if (ShrAmt <= ShlAmt) { + BitMask2 <<= (ShlAmt - ShrAmt); + } else { + BitMask2 = isLshr ? BitMask2.lshr(ShrAmt - ShlAmt): + BitMask2.ashr(ShrAmt - ShlAmt); + } + + // Check if condition-2 (see the comment to this function) is satified. + if ((BitMask1 & DemandedMask) == (BitMask2 & DemandedMask)) { + if (ShrAmt == ShlAmt) + return VarX; + + if (!Shr->hasOneUse()) + return nullptr; + + BinaryOperator *New; + if (ShrAmt < ShlAmt) { + Constant *Amt = ConstantInt::get(VarX->getType(), ShlAmt - ShrAmt); + New = BinaryOperator::CreateShl(VarX, Amt); + BinaryOperator *Orig = cast<BinaryOperator>(Shl); + New->setHasNoSignedWrap(Orig->hasNoSignedWrap()); + New->setHasNoUnsignedWrap(Orig->hasNoUnsignedWrap()); + } else { + Constant *Amt = ConstantInt::get(VarX->getType(), ShrAmt - ShlAmt); + New = isLshr ? 
BinaryOperator::CreateLShr(VarX, Amt) : + BinaryOperator::CreateAShr(VarX, Amt); + if (cast<BinaryOperator>(Shr)->isExact()) + New->setIsExact(true); + } + + return InsertNewInstWith(New, *Shl); + } + + return nullptr; +} + +/// The specified value produces a vector with any number of elements. +/// This method analyzes which elements of the operand are undef or poison and +/// returns that information in UndefElts. +/// +/// DemandedElts contains the set of elements that are actually used by the +/// caller, and by default (AllowMultipleUsers equals false) the value is +/// simplified only if it has a single caller. If AllowMultipleUsers is set +/// to true, DemandedElts refers to the union of sets of elements that are +/// used by all callers. +/// +/// If the information about demanded elements can be used to simplify the +/// operation, the operation is simplified, then the resultant value is +/// returned. This returns null if no change was made. +Value *InstCombinerImpl::SimplifyDemandedVectorElts(Value *V, + APInt DemandedElts, + APInt &UndefElts, + unsigned Depth, + bool AllowMultipleUsers) { + // Cannot analyze scalable type. The number of vector elements is not a + // compile-time constant. + if (isa<ScalableVectorType>(V->getType())) + return nullptr; + + unsigned VWidth = cast<FixedVectorType>(V->getType())->getNumElements(); + APInt EltMask(APInt::getAllOnes(VWidth)); + assert((DemandedElts & ~EltMask) == 0 && "Invalid DemandedElts!"); + + if (match(V, m_Undef())) { + // If the entire vector is undef or poison, just return this info. + UndefElts = EltMask; + return nullptr; + } + + if (DemandedElts.isZero()) { // If nothing is demanded, provide poison. + UndefElts = EltMask; + return PoisonValue::get(V->getType()); + } + + UndefElts = 0; + + if (auto *C = dyn_cast<Constant>(V)) { + // Check if this is identity. If so, return 0 since we are not simplifying + // anything. 
+ if (DemandedElts.isAllOnes()) + return nullptr; + + Type *EltTy = cast<VectorType>(V->getType())->getElementType(); + Constant *Poison = PoisonValue::get(EltTy); + SmallVector<Constant*, 16> Elts; + for (unsigned i = 0; i != VWidth; ++i) { + if (!DemandedElts[i]) { // If not demanded, set to poison. + Elts.push_back(Poison); + UndefElts.setBit(i); + continue; + } + + Constant *Elt = C->getAggregateElement(i); + if (!Elt) return nullptr; + + Elts.push_back(Elt); + if (isa<UndefValue>(Elt)) // Already undef or poison. + UndefElts.setBit(i); + } + + // If we changed the constant, return it. + Constant *NewCV = ConstantVector::get(Elts); + return NewCV != C ? NewCV : nullptr; + } + + // Limit search depth. + if (Depth == 10) + return nullptr; + + if (!AllowMultipleUsers) { + // If multiple users are using the root value, proceed with + // simplification conservatively assuming that all elements + // are needed. + if (!V->hasOneUse()) { + // Quit if we find multiple users of a non-root value though. + // They'll be handled when it's their turn to be visited by + // the main instcombine process. + if (Depth != 0) + // TODO: Just compute the UndefElts information recursively. + return nullptr; + + // Conservatively assume that all elements are needed. + DemandedElts = EltMask; + } + } + + Instruction *I = dyn_cast<Instruction>(V); + if (!I) return nullptr; // Only analyze instructions. + + bool MadeChange = false; + auto simplifyAndSetOp = [&](Instruction *Inst, unsigned OpNum, + APInt Demanded, APInt &Undef) { + auto *II = dyn_cast<IntrinsicInst>(Inst); + Value *Op = II ? 
II->getArgOperand(OpNum) : Inst->getOperand(OpNum); + if (Value *V = SimplifyDemandedVectorElts(Op, Demanded, Undef, Depth + 1)) { + replaceOperand(*Inst, OpNum, V); + MadeChange = true; + } + }; + + APInt UndefElts2(VWidth, 0); + APInt UndefElts3(VWidth, 0); + switch (I->getOpcode()) { + default: break; + + case Instruction::GetElementPtr: { + // The LangRef requires that struct geps have all constant indices. As + // such, we can't convert any operand to partial undef. + auto mayIndexStructType = [](GetElementPtrInst &GEP) { + for (auto I = gep_type_begin(GEP), E = gep_type_end(GEP); + I != E; I++) + if (I.isStruct()) + return true; + return false; + }; + if (mayIndexStructType(cast<GetElementPtrInst>(*I))) + break; + + // Conservatively track the demanded elements back through any vector + // operands we may have. We know there must be at least one, or we + // wouldn't have a vector result to get here. Note that we intentionally + // merge the undef bits here since gepping with either an poison base or + // index results in poison. + for (unsigned i = 0; i < I->getNumOperands(); i++) { + if (i == 0 ? match(I->getOperand(i), m_Undef()) + : match(I->getOperand(i), m_Poison())) { + // If the entire vector is undefined, just return this info. + UndefElts = EltMask; + return nullptr; + } + if (I->getOperand(i)->getType()->isVectorTy()) { + APInt UndefEltsOp(VWidth, 0); + simplifyAndSetOp(I, i, DemandedElts, UndefEltsOp); + // gep(x, undef) is not undef, so skip considering idx ops here + // Note that we could propagate poison, but we can't distinguish between + // undef & poison bits ATM + if (i == 0) + UndefElts |= UndefEltsOp; + } + } + + break; + } + case Instruction::InsertElement: { + // If this is a variable index, we don't know which element it overwrites. + // demand exactly the same input as we produce. 
+ ConstantInt *Idx = dyn_cast<ConstantInt>(I->getOperand(2)); + if (!Idx) { + // Note that we can't propagate undef elt info, because we don't know + // which elt is getting updated. + simplifyAndSetOp(I, 0, DemandedElts, UndefElts2); + break; + } + + // The element inserted overwrites whatever was there, so the input demanded + // set is simpler than the output set. + unsigned IdxNo = Idx->getZExtValue(); + APInt PreInsertDemandedElts = DemandedElts; + if (IdxNo < VWidth) + PreInsertDemandedElts.clearBit(IdxNo); + + // If we only demand the element that is being inserted and that element + // was extracted from the same index in another vector with the same type, + // replace this insert with that other vector. + // Note: This is attempted before the call to simplifyAndSetOp because that + // may change UndefElts to a value that does not match with Vec. + Value *Vec; + if (PreInsertDemandedElts == 0 && + match(I->getOperand(1), + m_ExtractElt(m_Value(Vec), m_SpecificInt(IdxNo))) && + Vec->getType() == I->getType()) { + return Vec; + } + + simplifyAndSetOp(I, 0, PreInsertDemandedElts, UndefElts); + + // If this is inserting an element that isn't demanded, remove this + // insertelement. + if (IdxNo >= VWidth || !DemandedElts[IdxNo]) { + Worklist.push(I); + return I->getOperand(0); + } + + // The inserted element is defined. + UndefElts.clearBit(IdxNo); + break; + } + case Instruction::ShuffleVector: { + auto *Shuffle = cast<ShuffleVectorInst>(I); + assert(Shuffle->getOperand(0)->getType() == + Shuffle->getOperand(1)->getType() && + "Expected shuffle operands to have same type"); + unsigned OpWidth = cast<FixedVectorType>(Shuffle->getOperand(0)->getType()) + ->getNumElements(); + // Handle trivial case of a splat. Only check the first element of LHS + // operand. 
+ if (all_of(Shuffle->getShuffleMask(), [](int Elt) { return Elt == 0; }) && + DemandedElts.isAllOnes()) { + if (!match(I->getOperand(1), m_Undef())) { + I->setOperand(1, PoisonValue::get(I->getOperand(1)->getType())); + MadeChange = true; + } + APInt LeftDemanded(OpWidth, 1); + APInt LHSUndefElts(OpWidth, 0); + simplifyAndSetOp(I, 0, LeftDemanded, LHSUndefElts); + if (LHSUndefElts[0]) + UndefElts = EltMask; + else + UndefElts.clearAllBits(); + break; + } + + APInt LeftDemanded(OpWidth, 0), RightDemanded(OpWidth, 0); + for (unsigned i = 0; i < VWidth; i++) { + if (DemandedElts[i]) { + unsigned MaskVal = Shuffle->getMaskValue(i); + if (MaskVal != -1u) { + assert(MaskVal < OpWidth * 2 && + "shufflevector mask index out of range!"); + if (MaskVal < OpWidth) + LeftDemanded.setBit(MaskVal); + else + RightDemanded.setBit(MaskVal - OpWidth); + } + } + } + + APInt LHSUndefElts(OpWidth, 0); + simplifyAndSetOp(I, 0, LeftDemanded, LHSUndefElts); + + APInt RHSUndefElts(OpWidth, 0); + simplifyAndSetOp(I, 1, RightDemanded, RHSUndefElts); + + // If this shuffle does not change the vector length and the elements + // demanded by this shuffle are an identity mask, then this shuffle is + // unnecessary. + // + // We are assuming canonical form for the mask, so the source vector is + // operand 0 and operand 1 is not used. + // + // Note that if an element is demanded and this shuffle mask is undefined + // for that element, then the shuffle is not considered an identity + // operation. The shuffle prevents poison from the operand vector from + // leaking to the result by replacing poison with an undefined value. 
+ if (VWidth == OpWidth) { + bool IsIdentityShuffle = true; + for (unsigned i = 0; i < VWidth; i++) { + unsigned MaskVal = Shuffle->getMaskValue(i); + if (DemandedElts[i] && i != MaskVal) { + IsIdentityShuffle = false; + break; + } + } + if (IsIdentityShuffle) + return Shuffle->getOperand(0); + } + + bool NewUndefElts = false; + unsigned LHSIdx = -1u, LHSValIdx = -1u; + unsigned RHSIdx = -1u, RHSValIdx = -1u; + bool LHSUniform = true; + bool RHSUniform = true; + for (unsigned i = 0; i < VWidth; i++) { + unsigned MaskVal = Shuffle->getMaskValue(i); + if (MaskVal == -1u) { + UndefElts.setBit(i); + } else if (!DemandedElts[i]) { + NewUndefElts = true; + UndefElts.setBit(i); + } else if (MaskVal < OpWidth) { + if (LHSUndefElts[MaskVal]) { + NewUndefElts = true; + UndefElts.setBit(i); + } else { + LHSIdx = LHSIdx == -1u ? i : OpWidth; + LHSValIdx = LHSValIdx == -1u ? MaskVal : OpWidth; + LHSUniform = LHSUniform && (MaskVal == i); + } + } else { + if (RHSUndefElts[MaskVal - OpWidth]) { + NewUndefElts = true; + UndefElts.setBit(i); + } else { + RHSIdx = RHSIdx == -1u ? i : OpWidth; + RHSValIdx = RHSValIdx == -1u ? MaskVal - OpWidth : OpWidth; + RHSUniform = RHSUniform && (MaskVal - OpWidth == i); + } + } + } + + // Try to transform shuffle with constant vector and single element from + // this constant vector to single insertelement instruction. + // shufflevector V, C, <v1, v2, .., ci, .., vm> -> + // insertelement V, C[ci], ci-n + if (OpWidth == + cast<FixedVectorType>(Shuffle->getType())->getNumElements()) { + Value *Op = nullptr; + Constant *Value = nullptr; + unsigned Idx = -1u; + + // Find constant vector with the single element in shuffle (LHS or RHS). 
+ if (LHSIdx < OpWidth && RHSUniform) { + if (auto *CV = dyn_cast<ConstantVector>(Shuffle->getOperand(0))) { + Op = Shuffle->getOperand(1); + Value = CV->getOperand(LHSValIdx); + Idx = LHSIdx; + } + } + if (RHSIdx < OpWidth && LHSUniform) { + if (auto *CV = dyn_cast<ConstantVector>(Shuffle->getOperand(1))) { + Op = Shuffle->getOperand(0); + Value = CV->getOperand(RHSValIdx); + Idx = RHSIdx; + } + } + // Found constant vector with single element - convert to insertelement. + if (Op && Value) { + Instruction *New = InsertElementInst::Create( + Op, Value, ConstantInt::get(Type::getInt32Ty(I->getContext()), Idx), + Shuffle->getName()); + InsertNewInstWith(New, *Shuffle); + return New; + } + } + if (NewUndefElts) { + // Add additional discovered undefs. + SmallVector<int, 16> Elts; + for (unsigned i = 0; i < VWidth; ++i) { + if (UndefElts[i]) + Elts.push_back(UndefMaskElem); + else + Elts.push_back(Shuffle->getMaskValue(i)); + } + Shuffle->setShuffleMask(Elts); + MadeChange = true; + } + break; + } + case Instruction::Select: { + // If this is a vector select, try to transform the select condition based + // on the current demanded elements. + SelectInst *Sel = cast<SelectInst>(I); + if (Sel->getCondition()->getType()->isVectorTy()) { + // TODO: We are not doing anything with UndefElts based on this call. + // It is overwritten below based on the other select operands. If an + // element of the select condition is known undef, then we are free to + // choose the output value from either arm of the select. If we know that + // one of those values is undef, then the output can be undef. + simplifyAndSetOp(I, 0, DemandedElts, UndefElts); + } + + // Next, see if we can transform the arms of the select. + APInt DemandedLHS(DemandedElts), DemandedRHS(DemandedElts); + if (auto *CV = dyn_cast<ConstantVector>(Sel->getCondition())) { + for (unsigned i = 0; i < VWidth; i++) { + // isNullValue() always returns false when called on a ConstantExpr. 
+ // Skip constant expressions to avoid propagating incorrect information. + Constant *CElt = CV->getAggregateElement(i); + if (isa<ConstantExpr>(CElt)) + continue; + // TODO: If a select condition element is undef, we can demand from + // either side. If one side is known undef, choosing that side would + // propagate undef. + if (CElt->isNullValue()) + DemandedLHS.clearBit(i); + else + DemandedRHS.clearBit(i); + } + } + + simplifyAndSetOp(I, 1, DemandedLHS, UndefElts2); + simplifyAndSetOp(I, 2, DemandedRHS, UndefElts3); + + // Output elements are undefined if the element from each arm is undefined. + // TODO: This can be improved. See comment in select condition handling. + UndefElts = UndefElts2 & UndefElts3; + break; + } + case Instruction::BitCast: { + // Vector->vector casts only. + VectorType *VTy = dyn_cast<VectorType>(I->getOperand(0)->getType()); + if (!VTy) break; + unsigned InVWidth = cast<FixedVectorType>(VTy)->getNumElements(); + APInt InputDemandedElts(InVWidth, 0); + UndefElts2 = APInt(InVWidth, 0); + unsigned Ratio; + + if (VWidth == InVWidth) { + // If we are converting from <4 x i32> -> <4 x f32>, we demand the same + // elements as are demanded of us. + Ratio = 1; + InputDemandedElts = DemandedElts; + } else if ((VWidth % InVWidth) == 0) { + // If the number of elements in the output is a multiple of the number of + // elements in the input then an input element is live if any of the + // corresponding output elements are live. + Ratio = VWidth / InVWidth; + for (unsigned OutIdx = 0; OutIdx != VWidth; ++OutIdx) + if (DemandedElts[OutIdx]) + InputDemandedElts.setBit(OutIdx / Ratio); + } else if ((InVWidth % VWidth) == 0) { + // If the number of elements in the input is a multiple of the number of + // elements in the output then an input element is live if the + // corresponding output element is live. 
+ Ratio = InVWidth / VWidth; + for (unsigned InIdx = 0; InIdx != InVWidth; ++InIdx) + if (DemandedElts[InIdx / Ratio]) + InputDemandedElts.setBit(InIdx); + } else { + // Unsupported so far. + break; + } + + simplifyAndSetOp(I, 0, InputDemandedElts, UndefElts2); + + if (VWidth == InVWidth) { + UndefElts = UndefElts2; + } else if ((VWidth % InVWidth) == 0) { + // If the number of elements in the output is a multiple of the number of + // elements in the input then an output element is undef if the + // corresponding input element is undef. + for (unsigned OutIdx = 0; OutIdx != VWidth; ++OutIdx) + if (UndefElts2[OutIdx / Ratio]) + UndefElts.setBit(OutIdx); + } else if ((InVWidth % VWidth) == 0) { + // If the number of elements in the input is a multiple of the number of + // elements in the output then an output element is undef if all of the + // corresponding input elements are undef. + for (unsigned OutIdx = 0; OutIdx != VWidth; ++OutIdx) { + APInt SubUndef = UndefElts2.lshr(OutIdx * Ratio).zextOrTrunc(Ratio); + if (SubUndef.countPopulation() == Ratio) + UndefElts.setBit(OutIdx); + } + } else { + llvm_unreachable("Unimp"); + } + break; + } + case Instruction::FPTrunc: + case Instruction::FPExt: + simplifyAndSetOp(I, 0, DemandedElts, UndefElts); + break; + + case Instruction::Call: { + IntrinsicInst *II = dyn_cast<IntrinsicInst>(I); + if (!II) break; + switch (II->getIntrinsicID()) { + case Intrinsic::masked_gather: // fallthrough + case Intrinsic::masked_load: { + // Subtlety: If we load from a pointer, the pointer must be valid + // regardless of whether the element is demanded. Doing otherwise risks + // segfaults which didn't exist in the original program. 
+ APInt DemandedPtrs(APInt::getAllOnes(VWidth)), + DemandedPassThrough(DemandedElts); + if (auto *CV = dyn_cast<ConstantVector>(II->getOperand(2))) + for (unsigned i = 0; i < VWidth; i++) { + Constant *CElt = CV->getAggregateElement(i); + if (CElt->isNullValue()) + DemandedPtrs.clearBit(i); + else if (CElt->isAllOnesValue()) + DemandedPassThrough.clearBit(i); + } + if (II->getIntrinsicID() == Intrinsic::masked_gather) + simplifyAndSetOp(II, 0, DemandedPtrs, UndefElts2); + simplifyAndSetOp(II, 3, DemandedPassThrough, UndefElts3); + + // Output elements are undefined if the element from both sources are. + // TODO: can strengthen via mask as well. + UndefElts = UndefElts2 & UndefElts3; + break; + } + default: { + // Handle target specific intrinsics + Optional<Value *> V = targetSimplifyDemandedVectorEltsIntrinsic( + *II, DemandedElts, UndefElts, UndefElts2, UndefElts3, + simplifyAndSetOp); + if (V) + return V.getValue(); + break; + } + } // switch on IntrinsicID + break; + } // case Call + } // switch on Opcode + + // TODO: We bail completely on integer div/rem and shifts because they have + // UB/poison potential, but that should be refined. + BinaryOperator *BO; + if (match(I, m_BinOp(BO)) && !BO->isIntDivRem() && !BO->isShift()) { + simplifyAndSetOp(I, 0, DemandedElts, UndefElts); + simplifyAndSetOp(I, 1, DemandedElts, UndefElts2); + + // Output elements are undefined if both are undefined. Consider things + // like undef & 0. The result is known zero, not undef. + UndefElts &= UndefElts2; + } + + // If we've proven all of the lanes undef, return an undef value. + // TODO: Intersect w/demanded lanes + if (UndefElts.isAllOnes()) + return UndefValue::get(I->getType());; + + return MadeChange ? 
I : nullptr; +} diff --git a/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp b/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp new file mode 100644 index 000000000000..22659a8e4951 --- /dev/null +++ b/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp @@ -0,0 +1,2914 @@ +//===- InstCombineVectorOps.cpp -------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements instcombine for ExtractElement, InsertElement and +// ShuffleVector. +// +//===----------------------------------------------------------------------===// + +#include "InstCombineInternal.h" +#include "llvm/ADT/APInt.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallBitVector.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/Statistic.h" +#include "llvm/Analysis/InstructionSimplify.h" +#include "llvm/Analysis/VectorUtils.h" +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/Constant.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/DerivedTypes.h" +#include "llvm/IR/InstrTypes.h" +#include "llvm/IR/Instruction.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/Operator.h" +#include "llvm/IR/PatternMatch.h" +#include "llvm/IR/Type.h" +#include "llvm/IR/User.h" +#include "llvm/IR/Value.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/ErrorHandling.h" +#include "llvm/Transforms/InstCombine/InstCombiner.h" +#include <cassert> +#include <cstdint> +#include <iterator> +#include <utility> + +#define DEBUG_TYPE "instcombine" + +using namespace llvm; +using namespace PatternMatch; + 
+STATISTIC(NumAggregateReconstructionsSimplified, + "Number of aggregate reconstructions turned into reuse of the " + "original aggregate"); + +/// Return true if the value is cheaper to scalarize than it is to leave as a +/// vector operation. If the extract index \p EI is a constant integer then +/// some operations may be cheap to scalarize. +/// +/// FIXME: It's possible to create more instructions than previously existed. +static bool cheapToScalarize(Value *V, Value *EI) { + ConstantInt *CEI = dyn_cast<ConstantInt>(EI); + + // If we can pick a scalar constant value out of a vector, that is free. + if (auto *C = dyn_cast<Constant>(V)) + return CEI || C->getSplatValue(); + + if (CEI && match(V, m_Intrinsic<Intrinsic::experimental_stepvector>())) { + ElementCount EC = cast<VectorType>(V->getType())->getElementCount(); + // Index needs to be lower than the minimum size of the vector, because + // for scalable vector, the vector size is known at run time. + return CEI->getValue().ult(EC.getKnownMinValue()); + } + + // An insertelement to the same constant index as our extract will simplify + // to the scalar inserted element. An insertelement to a different constant + // index is irrelevant to our extract. 
+ if (match(V, m_InsertElt(m_Value(), m_Value(), m_ConstantInt()))) + return CEI; + + if (match(V, m_OneUse(m_Load(m_Value())))) + return true; + + if (match(V, m_OneUse(m_UnOp()))) + return true; + + Value *V0, *V1; + if (match(V, m_OneUse(m_BinOp(m_Value(V0), m_Value(V1))))) + if (cheapToScalarize(V0, EI) || cheapToScalarize(V1, EI)) + return true; + + CmpInst::Predicate UnusedPred; + if (match(V, m_OneUse(m_Cmp(UnusedPred, m_Value(V0), m_Value(V1))))) + if (cheapToScalarize(V0, EI) || cheapToScalarize(V1, EI)) + return true; + + return false; +} + +// If we have a PHI node with a vector type that is only used to feed +// itself and be an operand of extractelement at a constant location, +// try to replace the PHI of the vector type with a PHI of a scalar type. +Instruction *InstCombinerImpl::scalarizePHI(ExtractElementInst &EI, + PHINode *PN) { + SmallVector<Instruction *, 2> Extracts; + // The users we want the PHI to have are: + // 1) The EI ExtractElement (we already know this) + // 2) Possibly more ExtractElements with the same index. + // 3) Another operand, which will feed back into the PHI. + Instruction *PHIUser = nullptr; + for (auto U : PN->users()) { + if (ExtractElementInst *EU = dyn_cast<ExtractElementInst>(U)) { + if (EI.getIndexOperand() == EU->getIndexOperand()) + Extracts.push_back(EU); + else + return nullptr; + } else if (!PHIUser) { + PHIUser = cast<Instruction>(U); + } else { + return nullptr; + } + } + + if (!PHIUser) + return nullptr; + + // Verify that this PHI user has one use, which is the PHI itself, + // and that it is a binary operation which is cheap to scalarize. + // otherwise return nullptr. + if (!PHIUser->hasOneUse() || !(PHIUser->user_back() == PN) || + !(isa<BinaryOperator>(PHIUser)) || + !cheapToScalarize(PHIUser, EI.getIndexOperand())) + return nullptr; + + // Create a scalar PHI node that will replace the vector PHI node + // just before the current PHI node. 
+ PHINode *scalarPHI = cast<PHINode>(InsertNewInstWith( + PHINode::Create(EI.getType(), PN->getNumIncomingValues(), ""), *PN)); + // Scalarize each PHI operand. + for (unsigned i = 0; i < PN->getNumIncomingValues(); i++) { + Value *PHIInVal = PN->getIncomingValue(i); + BasicBlock *inBB = PN->getIncomingBlock(i); + Value *Elt = EI.getIndexOperand(); + // If the operand is the PHI induction variable: + if (PHIInVal == PHIUser) { + // Scalarize the binary operation. Its first operand is the + // scalar PHI, and the second operand is extracted from the other + // vector operand. + BinaryOperator *B0 = cast<BinaryOperator>(PHIUser); + unsigned opId = (B0->getOperand(0) == PN) ? 1 : 0; + Value *Op = InsertNewInstWith( + ExtractElementInst::Create(B0->getOperand(opId), Elt, + B0->getOperand(opId)->getName() + ".Elt"), + *B0); + Value *newPHIUser = InsertNewInstWith( + BinaryOperator::CreateWithCopiedFlags(B0->getOpcode(), + scalarPHI, Op, B0), *B0); + scalarPHI->addIncoming(newPHIUser, inBB); + } else { + // Scalarize PHI input: + Instruction *newEI = ExtractElementInst::Create(PHIInVal, Elt, ""); + // Insert the new instruction into the predecessor basic block. 
+ Instruction *pos = dyn_cast<Instruction>(PHIInVal); + BasicBlock::iterator InsertPos; + if (pos && !isa<PHINode>(pos)) { + InsertPos = ++pos->getIterator(); + } else { + InsertPos = inBB->getFirstInsertionPt(); + } + + InsertNewInstWith(newEI, *InsertPos); + + scalarPHI->addIncoming(newEI, inBB); + } + } + + for (auto E : Extracts) + replaceInstUsesWith(*E, scalarPHI); + + return &EI; +} + +Instruction *InstCombinerImpl::foldBitcastExtElt(ExtractElementInst &Ext) { + Value *X; + uint64_t ExtIndexC; + if (!match(Ext.getVectorOperand(), m_BitCast(m_Value(X))) || + !match(Ext.getIndexOperand(), m_ConstantInt(ExtIndexC))) + return nullptr; + + ElementCount NumElts = + cast<VectorType>(Ext.getVectorOperandType())->getElementCount(); + Type *DestTy = Ext.getType(); + bool IsBigEndian = DL.isBigEndian(); + + // If we are casting an integer to vector and extracting a portion, that is + // a shift-right and truncate. + // TODO: Allow FP dest type by casting the trunc to FP? + if (X->getType()->isIntegerTy() && DestTy->isIntegerTy() && + isDesirableIntType(X->getType()->getPrimitiveSizeInBits())) { + assert(isa<FixedVectorType>(Ext.getVectorOperand()->getType()) && + "Expected fixed vector type for bitcast from scalar integer"); + + // Big endian requires adjusting the extract index since MSB is at index 0. 
+ // LittleEndian: extelt (bitcast i32 X to v4i8), 0 -> trunc i32 X to i8 + // BigEndian: extelt (bitcast i32 X to v4i8), 0 -> trunc i32 (X >> 24) to i8 + if (IsBigEndian) + ExtIndexC = NumElts.getKnownMinValue() - 1 - ExtIndexC; + unsigned ShiftAmountC = ExtIndexC * DestTy->getPrimitiveSizeInBits(); + if (!ShiftAmountC || Ext.getVectorOperand()->hasOneUse()) { + Value *Lshr = Builder.CreateLShr(X, ShiftAmountC, "extelt.offset"); + return new TruncInst(Lshr, DestTy); + } + } + + if (!X->getType()->isVectorTy()) + return nullptr; + + // If this extractelement is using a bitcast from a vector of the same number + // of elements, see if we can find the source element from the source vector: + // extelt (bitcast VecX), IndexC --> bitcast X[IndexC] + auto *SrcTy = cast<VectorType>(X->getType()); + ElementCount NumSrcElts = SrcTy->getElementCount(); + if (NumSrcElts == NumElts) + if (Value *Elt = findScalarElement(X, ExtIndexC)) + return new BitCastInst(Elt, DestTy); + + assert(NumSrcElts.isScalable() == NumElts.isScalable() && + "Src and Dst must be the same sort of vector type"); + + // If the source elements are wider than the destination, try to shift and + // truncate a subset of scalar bits of an insert op. + if (NumSrcElts.getKnownMinValue() < NumElts.getKnownMinValue()) { + Value *Scalar; + uint64_t InsIndexC; + if (!match(X, m_InsertElt(m_Value(), m_Value(Scalar), + m_ConstantInt(InsIndexC)))) + return nullptr; + + // The extract must be from the subset of vector elements that we inserted + // into. Example: if we inserted element 1 of a <2 x i64> and we are + // extracting an i16 (narrowing ratio = 4), then this extract must be from 1 + // of elements 4-7 of the bitcasted vector. + unsigned NarrowingRatio = + NumElts.getKnownMinValue() / NumSrcElts.getKnownMinValue(); + if (ExtIndexC / NarrowingRatio != InsIndexC) + return nullptr; + + // We are extracting part of the original scalar. How that scalar is + // inserted into the vector depends on the endian-ness. 
Example: + // Vector Byte Elt Index: 0 1 2 3 4 5 6 7 + // +--+--+--+--+--+--+--+--+ + // inselt <2 x i32> V, <i32> S, 1: |V0|V1|V2|V3|S0|S1|S2|S3| + // extelt <4 x i16> V', 3: | |S2|S3| + // +--+--+--+--+--+--+--+--+ + // If this is little-endian, S2|S3 are the MSB of the 32-bit 'S' value. + // If this is big-endian, S2|S3 are the LSB of the 32-bit 'S' value. + // In this example, we must right-shift little-endian. Big-endian is just a + // truncate. + unsigned Chunk = ExtIndexC % NarrowingRatio; + if (IsBigEndian) + Chunk = NarrowingRatio - 1 - Chunk; + + // Bail out if this is an FP vector to FP vector sequence. That would take + // more instructions than we started with unless there is no shift, and it + // may not be handled as well in the backend. + bool NeedSrcBitcast = SrcTy->getScalarType()->isFloatingPointTy(); + bool NeedDestBitcast = DestTy->isFloatingPointTy(); + if (NeedSrcBitcast && NeedDestBitcast) + return nullptr; + + unsigned SrcWidth = SrcTy->getScalarSizeInBits(); + unsigned DestWidth = DestTy->getPrimitiveSizeInBits(); + unsigned ShAmt = Chunk * DestWidth; + + // TODO: This limitation is more strict than necessary. We could sum the + // number of new instructions and subtract the number eliminated to know if + // we can proceed. + if (!X->hasOneUse() || !Ext.getVectorOperand()->hasOneUse()) + if (NeedSrcBitcast || NeedDestBitcast) + return nullptr; + + if (NeedSrcBitcast) { + Type *SrcIntTy = IntegerType::getIntNTy(Scalar->getContext(), SrcWidth); + Scalar = Builder.CreateBitCast(Scalar, SrcIntTy); + } + + if (ShAmt) { + // Bail out if we could end with more instructions than we started with. 
+ if (!Ext.getVectorOperand()->hasOneUse()) + return nullptr; + Scalar = Builder.CreateLShr(Scalar, ShAmt); + } + + if (NeedDestBitcast) { + Type *DestIntTy = IntegerType::getIntNTy(Scalar->getContext(), DestWidth); + return new BitCastInst(Builder.CreateTrunc(Scalar, DestIntTy), DestTy); + } + return new TruncInst(Scalar, DestTy); + } + + return nullptr; +} + +/// Find elements of V demanded by UserInstr. +static APInt findDemandedEltsBySingleUser(Value *V, Instruction *UserInstr) { + unsigned VWidth = cast<FixedVectorType>(V->getType())->getNumElements(); + + // Conservatively assume that all elements are needed. + APInt UsedElts(APInt::getAllOnes(VWidth)); + + switch (UserInstr->getOpcode()) { + case Instruction::ExtractElement: { + ExtractElementInst *EEI = cast<ExtractElementInst>(UserInstr); + assert(EEI->getVectorOperand() == V); + ConstantInt *EEIIndexC = dyn_cast<ConstantInt>(EEI->getIndexOperand()); + if (EEIIndexC && EEIIndexC->getValue().ult(VWidth)) { + UsedElts = APInt::getOneBitSet(VWidth, EEIIndexC->getZExtValue()); + } + break; + } + case Instruction::ShuffleVector: { + ShuffleVectorInst *Shuffle = cast<ShuffleVectorInst>(UserInstr); + unsigned MaskNumElts = + cast<FixedVectorType>(UserInstr->getType())->getNumElements(); + + UsedElts = APInt(VWidth, 0); + for (unsigned i = 0; i < MaskNumElts; i++) { + unsigned MaskVal = Shuffle->getMaskValue(i); + if (MaskVal == -1u || MaskVal >= 2 * VWidth) + continue; + if (Shuffle->getOperand(0) == V && (MaskVal < VWidth)) + UsedElts.setBit(MaskVal); + if (Shuffle->getOperand(1) == V && + ((MaskVal >= VWidth) && (MaskVal < 2 * VWidth))) + UsedElts.setBit(MaskVal - VWidth); + } + break; + } + default: + break; + } + return UsedElts; +} + +/// Find union of elements of V demanded by all its users. +/// If it is known by querying findDemandedEltsBySingleUser that +/// no user demands an element of V, then the corresponding bit +/// remains unset in the returned value. 
static APInt findDemandedEltsByAllUsers(Value *V) {
  unsigned VWidth = cast<FixedVectorType>(V->getType())->getNumElements();

  // Accumulate the per-user demanded masks; start from "nothing demanded".
  APInt UnionUsedElts(VWidth, 0);
  for (const Use &U : V->uses()) {
    if (Instruction *I = dyn_cast<Instruction>(U.getUser())) {
      UnionUsedElts |= findDemandedEltsBySingleUser(V, I);
    } else {
      // A non-instruction user cannot be analyzed; assume it needs everything.
      UnionUsedElts = APInt::getAllOnes(VWidth);
      break;
    }

    // Once every element is demanded, further users cannot change the result.
    if (UnionUsedElts.isAllOnes())
      break;
  }

  return UnionUsedElts;
}

/// Given a constant index for an extractelement or insertelement instruction,
/// return it with the canonical type if it isn't already canonical. We
/// arbitrarily pick 64 bit as our canonical type. The actual bitwidth doesn't
/// matter, we just want a consistent type to simplify CSE.
///
/// \returns the index as a 64-bit constant, or nullptr when \p IndexC is
/// already 64 bits wide or its value needs more than 64 bits.
ConstantInt *getPreferredVectorIndex(ConstantInt *IndexC) {
  const unsigned IndexBW = IndexC->getType()->getBitWidth();
  if (IndexBW == 64 || IndexC->getValue().getActiveBits() > 64)
    return nullptr;
  return ConstantInt::get(IndexC->getContext(),
                          IndexC->getValue().zextOrTrunc(64));
}

/// Try to simplify or fold the extractelement instruction \p EI.
///
/// The folds are attempted in the order written below and the first one that
/// fires returns. \returns a replacement instruction or the updated \p EI when
/// a transform was applied, or nullptr when nothing changed.
Instruction *InstCombinerImpl::visitExtractElementInst(ExtractElementInst &EI) {
  Value *SrcVec = EI.getVectorOperand();
  Value *Index = EI.getIndexOperand();
  if (Value *V = simplifyExtractElementInst(SrcVec, Index,
                                            SQ.getWithInstruction(&EI)))
    return replaceInstUsesWith(EI, V);

  // If extracting a specified index from the vector, see if we can recursively
  // find a previously computed scalar that was inserted into the vector.
  auto *IndexC = dyn_cast<ConstantInt>(Index);
  if (IndexC) {
    // Canonicalize type of constant indices to i64 to simplify CSE
    if (auto *NewIdx = getPreferredVectorIndex(IndexC))
      return replaceOperand(EI, 1, NewIdx);

    ElementCount EC = EI.getVectorOperandType()->getElementCount();
    unsigned NumElts = EC.getKnownMinValue();

    if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(SrcVec)) {
      Intrinsic::ID IID = II->getIntrinsicID();
      // Index needs to be lower than the minimum size of the vector, because
      // for scalable vector, the vector size is known at run time.
      if (IID == Intrinsic::experimental_stepvector &&
          IndexC->getValue().ult(NumElts)) {
        Type *Ty = EI.getType();
        unsigned BitWidth = Ty->getIntegerBitWidth();
        Value *Idx;
        // Return index when its value does not exceed the allowed limit
        // for the element type of the vector, otherwise return undefined.
        if (IndexC->getValue().getActiveBits() <= BitWidth)
          Idx = ConstantInt::get(Ty, IndexC->getValue().zextOrTrunc(BitWidth));
        else
          Idx = UndefValue::get(Ty);
        return replaceInstUsesWith(EI, Idx);
      }
    }

    // InstSimplify should handle cases where the index is invalid.
    // For fixed-length vector, it's invalid to extract out-of-range element.
    if (!EC.isScalable() && IndexC->getValue().uge(NumElts))
      return nullptr;

    if (Instruction *I = foldBitcastExtElt(EI))
      return I;

    // If there's a vector PHI feeding a scalar use through this extractelement
    // instruction, try to scalarize the PHI.
    if (auto *Phi = dyn_cast<PHINode>(SrcVec))
      if (Instruction *ScalarPHI = scalarizePHI(EI, Phi))
        return ScalarPHI;
  }

  // TODO come up with a n-ary matcher that subsumes both unary and
  // binary matchers.
  UnaryOperator *UO;
  if (match(SrcVec, m_UnOp(UO)) && cheapToScalarize(SrcVec, Index)) {
    // extelt (unop X), Index --> unop (extelt X, Index)
    Value *X = UO->getOperand(0);
    Value *E = Builder.CreateExtractElement(X, Index);
    return UnaryOperator::CreateWithCopiedFlags(UO->getOpcode(), E, UO);
  }

  BinaryOperator *BO;
  if (match(SrcVec, m_BinOp(BO)) && cheapToScalarize(SrcVec, Index)) {
    // extelt (binop X, Y), Index --> binop (extelt X, Index), (extelt Y, Index)
    Value *X = BO->getOperand(0), *Y = BO->getOperand(1);
    Value *E0 = Builder.CreateExtractElement(X, Index);
    Value *E1 = Builder.CreateExtractElement(Y, Index);
    return BinaryOperator::CreateWithCopiedFlags(BO->getOpcode(), E0, E1, BO);
  }

  Value *X, *Y;
  CmpInst::Predicate Pred;
  if (match(SrcVec, m_Cmp(Pred, m_Value(X), m_Value(Y))) &&
      cheapToScalarize(SrcVec, Index)) {
    // extelt (cmp X, Y), Index --> cmp (extelt X, Index), (extelt Y, Index)
    Value *E0 = Builder.CreateExtractElement(X, Index);
    Value *E1 = Builder.CreateExtractElement(Y, Index);
    return CmpInst::Create(cast<CmpInst>(SrcVec)->getOpcode(), Pred, E0, E1);
  }

  if (auto *I = dyn_cast<Instruction>(SrcVec)) {
    if (auto *IE = dyn_cast<InsertElementInst>(I)) {
      // instsimplify already handled the case where the indices are constants
      // and equal by value, if both are constants, they must not be the same
      // value, extract from the pre-inserted value instead.
      if (isa<Constant>(IE->getOperand(2)) && IndexC)
        return replaceOperand(EI, 0, IE->getOperand(0));
    } else if (auto *GEP = dyn_cast<GetElementPtrInst>(I)) {
      auto *VecType = cast<VectorType>(GEP->getType());
      ElementCount EC = VecType->getElementCount();
      uint64_t IdxVal = IndexC ? IndexC->getZExtValue() : 0;
      if (IndexC && IdxVal < EC.getKnownMinValue() && GEP->hasOneUse()) {
        // Find out why we have a vector result - these are a few examples:
        //  1. We have a scalar pointer and a vector of indices, or
        //  2. We have a vector of pointers and a scalar index, or
        //  3. We have a vector of pointers and a vector of indices, etc.
        // Here we only consider combining when there is exactly one vector
        // operand, since the optimization is less obviously a win due to
        // needing more than one extractelements.

        unsigned VectorOps =
            llvm::count_if(GEP->operands(), [](const Value *V) {
              return isa<VectorType>(V->getType());
            });
        if (VectorOps == 1) {
          Value *NewPtr = GEP->getPointerOperand();
          if (isa<VectorType>(NewPtr->getType()))
            NewPtr = Builder.CreateExtractElement(NewPtr, IndexC);

          // Rebuild the index list, extracting the lane from whichever index
          // operand is a vector and passing scalar indices through unchanged.
          SmallVector<Value *> NewOps;
          for (unsigned I = 1; I != GEP->getNumOperands(); ++I) {
            Value *Op = GEP->getOperand(I);
            if (isa<VectorType>(Op->getType()))
              NewOps.push_back(Builder.CreateExtractElement(Op, IndexC));
            else
              NewOps.push_back(Op);
          }

          GetElementPtrInst *NewGEP = GetElementPtrInst::Create(
              GEP->getSourceElementType(), NewPtr, NewOps);
          NewGEP->setIsInBounds(GEP->isInBounds());
          return NewGEP;
        }
      }
    } else if (auto *SVI = dyn_cast<ShuffleVectorInst>(I)) {
      // If this is extracting an element from a shufflevector, figure out where
      // it came from and extract from the appropriate input element instead.
      // Restrict the following transformation to fixed-length vector.
      if (isa<FixedVectorType>(SVI->getType()) && isa<ConstantInt>(Index)) {
        int SrcIdx =
            SVI->getMaskValue(cast<ConstantInt>(Index)->getZExtValue());
        Value *Src;
        unsigned LHSWidth = cast<FixedVectorType>(SVI->getOperand(0)->getType())
                                ->getNumElements();

        // A negative mask value means the shuffle lane is undefined.
        if (SrcIdx < 0)
          return replaceInstUsesWith(EI, UndefValue::get(EI.getType()));
        if (SrcIdx < (int)LHSWidth)
          Src = SVI->getOperand(0);
        else {
          SrcIdx -= LHSWidth;
          Src = SVI->getOperand(1);
        }
        Type *Int32Ty = Type::getInt32Ty(EI.getContext());
        return ExtractElementInst::Create(
            Src, ConstantInt::get(Int32Ty, SrcIdx, false));
      }
    } else if (auto *CI = dyn_cast<CastInst>(I)) {
      // Canonicalize extractelement(cast) -> cast(extractelement).
      // Bitcasts can change the number of vector elements, and they cost
      // nothing.
      if (CI->hasOneUse() && (CI->getOpcode() != Instruction::BitCast)) {
        Value *EE = Builder.CreateExtractElement(CI->getOperand(0), Index);
        return CastInst::Create(CI->getOpcode(), EE, EI.getType());
      }
    }
  }

  // Run demanded elements after other transforms as this can drop flags on
  // binops. If there are two paths to the same final result, we prefer the
  // one which doesn't force us to drop flags.
  if (IndexC) {
    ElementCount EC = EI.getVectorOperandType()->getElementCount();
    unsigned NumElts = EC.getKnownMinValue();
    // This instruction only demands the single element from the input vector.
    // Skip for scalable type, the number of elements is unknown at
    // compile-time.
    if (!EC.isScalable() && NumElts != 1) {
      // If the input vector has a single use, simplify it based on this use
      // property.
      if (SrcVec->hasOneUse()) {
        APInt UndefElts(NumElts, 0);
        APInt DemandedElts(NumElts, 0);
        DemandedElts.setBit(IndexC->getZExtValue());
        if (Value *V =
                SimplifyDemandedVectorElts(SrcVec, DemandedElts, UndefElts))
          return replaceOperand(EI, 0, V);
      } else {
        // If the input vector has multiple uses, simplify it based on a union
        // of all elements used.
        APInt DemandedElts = findDemandedEltsByAllUsers(SrcVec);
        if (!DemandedElts.isAllOnes()) {
          APInt UndefElts(NumElts, 0);
          if (Value *V = SimplifyDemandedVectorElts(
                  SrcVec, DemandedElts, UndefElts, 0 /* Depth */,
                  true /* AllowMultipleUsers */)) {
            if (V != SrcVec) {
              SrcVec->replaceAllUsesWith(V);
              return &EI;
            }
          }
        }
      }
    }
  }
  return nullptr;
}

/// If V is a shuffle of values that ONLY returns elements from either LHS or
/// RHS, return the shuffle mask and true. Otherwise, return false.
static bool collectSingleShuffleElements(Value *V, Value *LHS, Value *RHS,
                                         SmallVectorImpl<int> &Mask) {
  assert(LHS->getType() == RHS->getType() &&
         "Invalid CollectSingleShuffleElements");
  unsigned NumElts = cast<FixedVectorType>(V->getType())->getNumElements();

  if (match(V, m_Undef())) {
    Mask.assign(NumElts, -1);
    return true;
  }

  if (V == LHS) {
    for (unsigned i = 0; i != NumElts; ++i)
      Mask.push_back(i);
    return true;
  }

  if (V == RHS) {
    for (unsigned i = 0; i != NumElts; ++i)
      Mask.push_back(i + NumElts);
    return true;
  }

  if (InsertElementInst *IEI = dyn_cast<InsertElementInst>(V)) {
    // If this is an insert of an extract from some other vector, include it.
    Value *VecOp = IEI->getOperand(0);
    Value *ScalarOp = IEI->getOperand(1);
    Value *IdxOp = IEI->getOperand(2);

    if (!isa<ConstantInt>(IdxOp))
      return false;
    unsigned InsertedIdx = cast<ConstantInt>(IdxOp)->getZExtValue();

    if (isa<UndefValue>(ScalarOp)) { // inserting undef into vector.
      // We can handle this if the vector we are inserting into is
      // transitively ok.
+ if (collectSingleShuffleElements(VecOp, LHS, RHS, Mask)) { + // If so, update the mask to reflect the inserted undef. + Mask[InsertedIdx] = -1; + return true; + } + } else if (ExtractElementInst *EI = dyn_cast<ExtractElementInst>(ScalarOp)){ + if (isa<ConstantInt>(EI->getOperand(1))) { + unsigned ExtractedIdx = + cast<ConstantInt>(EI->getOperand(1))->getZExtValue(); + unsigned NumLHSElts = + cast<FixedVectorType>(LHS->getType())->getNumElements(); + + // This must be extracting from either LHS or RHS. + if (EI->getOperand(0) == LHS || EI->getOperand(0) == RHS) { + // We can handle this if the vector we are inserting into is + // transitively ok. + if (collectSingleShuffleElements(VecOp, LHS, RHS, Mask)) { + // If so, update the mask to reflect the inserted value. + if (EI->getOperand(0) == LHS) { + Mask[InsertedIdx % NumElts] = ExtractedIdx; + } else { + assert(EI->getOperand(0) == RHS); + Mask[InsertedIdx % NumElts] = ExtractedIdx + NumLHSElts; + } + return true; + } + } + } + } + } + + return false; +} + +/// If we have insertion into a vector that is wider than the vector that we +/// are extracting from, try to widen the source vector to allow a single +/// shufflevector to replace one or more insert/extract pairs. +static void replaceExtractElements(InsertElementInst *InsElt, + ExtractElementInst *ExtElt, + InstCombinerImpl &IC) { + auto *InsVecType = cast<FixedVectorType>(InsElt->getType()); + auto *ExtVecType = cast<FixedVectorType>(ExtElt->getVectorOperandType()); + unsigned NumInsElts = InsVecType->getNumElements(); + unsigned NumExtElts = ExtVecType->getNumElements(); + + // The inserted-to vector must be wider than the extracted-from vector. + if (InsVecType->getElementType() != ExtVecType->getElementType() || + NumExtElts >= NumInsElts) + return; + + // Create a shuffle mask to widen the extended-from vector using poison + // values. 
The mask selects all of the values of the original vector followed + // by as many poison values as needed to create a vector of the same length + // as the inserted-to vector. + SmallVector<int, 16> ExtendMask; + for (unsigned i = 0; i < NumExtElts; ++i) + ExtendMask.push_back(i); + for (unsigned i = NumExtElts; i < NumInsElts; ++i) + ExtendMask.push_back(-1); + + Value *ExtVecOp = ExtElt->getVectorOperand(); + auto *ExtVecOpInst = dyn_cast<Instruction>(ExtVecOp); + BasicBlock *InsertionBlock = (ExtVecOpInst && !isa<PHINode>(ExtVecOpInst)) + ? ExtVecOpInst->getParent() + : ExtElt->getParent(); + + // TODO: This restriction matches the basic block check below when creating + // new extractelement instructions. If that limitation is removed, this one + // could also be removed. But for now, we just bail out to ensure that we + // will replace the extractelement instruction that is feeding our + // insertelement instruction. This allows the insertelement to then be + // replaced by a shufflevector. If the insertelement is not replaced, we can + // induce infinite looping because there's an optimization for extractelement + // that will delete our widening shuffle. This would trigger another attempt + // here to create that shuffle, and we spin forever. + if (InsertionBlock != InsElt->getParent()) + return; + + // TODO: This restriction matches the check in visitInsertElementInst() and + // prevents an infinite loop caused by not turning the extract/insert pair + // into a shuffle. We really should not need either check, but we're lacking + // folds for shufflevectors because we're afraid to generate shuffle masks + // that the backend can't handle. 
+ if (InsElt->hasOneUse() && isa<InsertElementInst>(InsElt->user_back())) + return; + + auto *WideVec = new ShuffleVectorInst(ExtVecOp, ExtendMask); + + // Insert the new shuffle after the vector operand of the extract is defined + // (as long as it's not a PHI) or at the start of the basic block of the + // extract, so any subsequent extracts in the same basic block can use it. + // TODO: Insert before the earliest ExtractElementInst that is replaced. + if (ExtVecOpInst && !isa<PHINode>(ExtVecOpInst)) + WideVec->insertAfter(ExtVecOpInst); + else + IC.InsertNewInstWith(WideVec, *ExtElt->getParent()->getFirstInsertionPt()); + + // Replace extracts from the original narrow vector with extracts from the new + // wide vector. + for (User *U : ExtVecOp->users()) { + ExtractElementInst *OldExt = dyn_cast<ExtractElementInst>(U); + if (!OldExt || OldExt->getParent() != WideVec->getParent()) + continue; + auto *NewExt = ExtractElementInst::Create(WideVec, OldExt->getOperand(1)); + NewExt->insertAfter(OldExt); + IC.replaceInstUsesWith(*OldExt, NewExt); + } +} + +/// We are building a shuffle to create V, which is a sequence of insertelement, +/// extractelement pairs. If PermittedRHS is set, then we must either use it or +/// not rely on the second vector source. Return a std::pair containing the +/// left and right vectors of the proposed shuffle (or 0), and set the Mask +/// parameter as required. +/// +/// Note: we intentionally don't try to fold earlier shuffles since they have +/// often been chosen carefully to be efficiently implementable on the target. +using ShuffleOps = std::pair<Value *, Value *>; + +static ShuffleOps collectShuffleElements(Value *V, SmallVectorImpl<int> &Mask, + Value *PermittedRHS, + InstCombinerImpl &IC) { + assert(V->getType()->isVectorTy() && "Invalid shuffle!"); + unsigned NumElts = cast<FixedVectorType>(V->getType())->getNumElements(); + + if (match(V, m_Undef())) { + Mask.assign(NumElts, -1); + return std::make_pair( + PermittedRHS ? 
UndefValue::get(PermittedRHS->getType()) : V, nullptr); + } + + if (isa<ConstantAggregateZero>(V)) { + Mask.assign(NumElts, 0); + return std::make_pair(V, nullptr); + } + + if (InsertElementInst *IEI = dyn_cast<InsertElementInst>(V)) { + // If this is an insert of an extract from some other vector, include it. + Value *VecOp = IEI->getOperand(0); + Value *ScalarOp = IEI->getOperand(1); + Value *IdxOp = IEI->getOperand(2); + + if (ExtractElementInst *EI = dyn_cast<ExtractElementInst>(ScalarOp)) { + if (isa<ConstantInt>(EI->getOperand(1)) && isa<ConstantInt>(IdxOp)) { + unsigned ExtractedIdx = + cast<ConstantInt>(EI->getOperand(1))->getZExtValue(); + unsigned InsertedIdx = cast<ConstantInt>(IdxOp)->getZExtValue(); + + // Either the extracted from or inserted into vector must be RHSVec, + // otherwise we'd end up with a shuffle of three inputs. + if (EI->getOperand(0) == PermittedRHS || PermittedRHS == nullptr) { + Value *RHS = EI->getOperand(0); + ShuffleOps LR = collectShuffleElements(VecOp, Mask, RHS, IC); + assert(LR.second == nullptr || LR.second == RHS); + + if (LR.first->getType() != RHS->getType()) { + // Although we are giving up for now, see if we can create extracts + // that match the inserts for another round of combining. + replaceExtractElements(IEI, EI, IC); + + // We tried our best, but we can't find anything compatible with RHS + // further up the chain. Return a trivial shuffle. + for (unsigned i = 0; i < NumElts; ++i) + Mask[i] = i; + return std::make_pair(V, nullptr); + } + + unsigned NumLHSElts = + cast<FixedVectorType>(RHS->getType())->getNumElements(); + Mask[InsertedIdx % NumElts] = NumLHSElts + ExtractedIdx; + return std::make_pair(LR.first, RHS); + } + + if (VecOp == PermittedRHS) { + // We've gone as far as we can: anything on the other side of the + // extractelement will already have been converted into a shuffle. 
+ unsigned NumLHSElts = + cast<FixedVectorType>(EI->getOperand(0)->getType()) + ->getNumElements(); + for (unsigned i = 0; i != NumElts; ++i) + Mask.push_back(i == InsertedIdx ? ExtractedIdx : NumLHSElts + i); + return std::make_pair(EI->getOperand(0), PermittedRHS); + } + + // If this insertelement is a chain that comes from exactly these two + // vectors, return the vector and the effective shuffle. + if (EI->getOperand(0)->getType() == PermittedRHS->getType() && + collectSingleShuffleElements(IEI, EI->getOperand(0), PermittedRHS, + Mask)) + return std::make_pair(EI->getOperand(0), PermittedRHS); + } + } + } + + // Otherwise, we can't do anything fancy. Return an identity vector. + for (unsigned i = 0; i != NumElts; ++i) + Mask.push_back(i); + return std::make_pair(V, nullptr); +} + +/// Look for chain of insertvalue's that fully define an aggregate, and trace +/// back the values inserted, see if they are all were extractvalue'd from +/// the same source aggregate from the exact same element indexes. +/// If they were, just reuse the source aggregate. +/// This potentially deals with PHI indirections. +Instruction *InstCombinerImpl::foldAggregateConstructionIntoAggregateReuse( + InsertValueInst &OrigIVI) { + Type *AggTy = OrigIVI.getType(); + unsigned NumAggElts; + switch (AggTy->getTypeID()) { + case Type::StructTyID: + NumAggElts = AggTy->getStructNumElements(); + break; + case Type::ArrayTyID: + NumAggElts = AggTy->getArrayNumElements(); + break; + default: + llvm_unreachable("Unhandled aggregate type?"); + } + + // Arbitrary aggregate size cut-off. Motivation for limit of 2 is to be able + // to handle clang C++ exception struct (which is hardcoded as {i8*, i32}), + // FIXME: any interesting patterns to be caught with larger limit? 
+ assert(NumAggElts > 0 && "Aggregate should have elements."); + if (NumAggElts > 2) + return nullptr; + + static constexpr auto NotFound = None; + static constexpr auto FoundMismatch = nullptr; + + // Try to find a value of each element of an aggregate. + // FIXME: deal with more complex, not one-dimensional, aggregate types + SmallVector<Optional<Instruction *>, 2> AggElts(NumAggElts, NotFound); + + // Do we know values for each element of the aggregate? + auto KnowAllElts = [&AggElts]() { + return all_of(AggElts, + [](Optional<Instruction *> Elt) { return Elt != NotFound; }); + }; + + int Depth = 0; + + // Arbitrary `insertvalue` visitation depth limit. Let's be okay with + // every element being overwritten twice, which should never happen. + static const int DepthLimit = 2 * NumAggElts; + + // Recurse up the chain of `insertvalue` aggregate operands until either we've + // reconstructed full initializer or can't visit any more `insertvalue`'s. + for (InsertValueInst *CurrIVI = &OrigIVI; + Depth < DepthLimit && CurrIVI && !KnowAllElts(); + CurrIVI = dyn_cast<InsertValueInst>(CurrIVI->getAggregateOperand()), + ++Depth) { + auto *InsertedValue = + dyn_cast<Instruction>(CurrIVI->getInsertedValueOperand()); + if (!InsertedValue) + return nullptr; // Inserted value must be produced by an instruction. + + ArrayRef<unsigned int> Indices = CurrIVI->getIndices(); + + // Don't bother with more than single-level aggregates. + if (Indices.size() != 1) + return nullptr; // FIXME: deal with more complex aggregates? + + // Now, we may have already previously recorded the value for this element + // of an aggregate. If we did, that means the CurrIVI will later be + // overwritten with the already-recorded value. But if not, let's record it! + Optional<Instruction *> &Elt = AggElts[Indices.front()]; + Elt = Elt.value_or(InsertedValue); + + // FIXME: should we handle chain-terminating undef base operand? 
+ } + + // Was that sufficient to deduce the full initializer for the aggregate? + if (!KnowAllElts()) + return nullptr; // Give up then. + + // We now want to find the source[s] of the aggregate elements we've found. + // And with "source" we mean the original aggregate[s] from which + // the inserted elements were extracted. This may require PHI translation. + + enum class AggregateDescription { + /// When analyzing the value that was inserted into an aggregate, we did + /// not manage to find defining `extractvalue` instruction to analyze. + NotFound, + /// When analyzing the value that was inserted into an aggregate, we did + /// manage to find defining `extractvalue` instruction[s], and everything + /// matched perfectly - aggregate type, element insertion/extraction index. + Found, + /// When analyzing the value that was inserted into an aggregate, we did + /// manage to find defining `extractvalue` instruction, but there was + /// a mismatch: either the source type from which the extraction was didn't + /// match the aggregate type into which the insertion was, + /// or the extraction/insertion channels mismatched, + /// or different elements had different source aggregates. + FoundMismatch + }; + auto Describe = [](Optional<Value *> SourceAggregate) { + if (SourceAggregate == NotFound) + return AggregateDescription::NotFound; + if (*SourceAggregate == FoundMismatch) + return AggregateDescription::FoundMismatch; + return AggregateDescription::Found; + }; + + // Given the value \p Elt that was being inserted into element \p EltIdx of an + // aggregate AggTy, see if \p Elt was originally defined by an + // appropriate extractvalue (same element index, same aggregate type). + // If found, return the source aggregate from which the extraction was. + // If \p PredBB is provided, does PHI translation of an \p Elt first. 
+ auto FindSourceAggregate = + [&](Instruction *Elt, unsigned EltIdx, Optional<BasicBlock *> UseBB, + Optional<BasicBlock *> PredBB) -> Optional<Value *> { + // For now(?), only deal with, at most, a single level of PHI indirection. + if (UseBB && PredBB) + Elt = dyn_cast<Instruction>(Elt->DoPHITranslation(*UseBB, *PredBB)); + // FIXME: deal with multiple levels of PHI indirection? + + // Did we find an extraction? + auto *EVI = dyn_cast_or_null<ExtractValueInst>(Elt); + if (!EVI) + return NotFound; + + Value *SourceAggregate = EVI->getAggregateOperand(); + + // Is the extraction from the same type into which the insertion was? + if (SourceAggregate->getType() != AggTy) + return FoundMismatch; + // And the element index doesn't change between extraction and insertion? + if (EVI->getNumIndices() != 1 || EltIdx != EVI->getIndices().front()) + return FoundMismatch; + + return SourceAggregate; // AggregateDescription::Found + }; + + // Given elements AggElts that were constructing an aggregate OrigIVI, + // see if we can find appropriate source aggregate for each of the elements, + // and see it's the same aggregate for each element. If so, return it. + auto FindCommonSourceAggregate = + [&](Optional<BasicBlock *> UseBB, + Optional<BasicBlock *> PredBB) -> Optional<Value *> { + Optional<Value *> SourceAggregate; + + for (auto I : enumerate(AggElts)) { + assert(Describe(SourceAggregate) != AggregateDescription::FoundMismatch && + "We don't store nullptr in SourceAggregate!"); + assert((Describe(SourceAggregate) == AggregateDescription::Found) == + (I.index() != 0) && + "SourceAggregate should be valid after the first element,"); + + // For this element, is there a plausible source aggregate? + // FIXME: we could special-case undef element, IFF we know that in the + // source aggregate said element isn't poison. + Optional<Value *> SourceAggregateForElement = + FindSourceAggregate(*I.value(), I.index(), UseBB, PredBB); + + // Okay, what have we found? 
Does that correlate with previous findings? + + // Regardless of whether or not we have previously found source + // aggregate for previous elements (if any), if we didn't find one for + // this element, passthrough whatever we have just found. + if (Describe(SourceAggregateForElement) != AggregateDescription::Found) + return SourceAggregateForElement; + + // Okay, we have found source aggregate for this element. + // Let's see what we already know from previous elements, if any. + switch (Describe(SourceAggregate)) { + case AggregateDescription::NotFound: + // This is apparently the first element that we have examined. + SourceAggregate = SourceAggregateForElement; // Record the aggregate! + continue; // Great, now look at next element. + case AggregateDescription::Found: + // We have previously already successfully examined other elements. + // Is this the same source aggregate we've found for other elements? + if (*SourceAggregateForElement != *SourceAggregate) + return FoundMismatch; + continue; // Still the same aggregate, look at next element. + case AggregateDescription::FoundMismatch: + llvm_unreachable("Can't happen. We would have early-exited then."); + }; + } + + assert(Describe(SourceAggregate) == AggregateDescription::Found && + "Must be a valid Value"); + return *SourceAggregate; + }; + + Optional<Value *> SourceAggregate; + + // Can we find the source aggregate without looking at predecessors? + SourceAggregate = FindCommonSourceAggregate(/*UseBB=*/None, /*PredBB=*/None); + if (Describe(SourceAggregate) != AggregateDescription::NotFound) { + if (Describe(SourceAggregate) == AggregateDescription::FoundMismatch) + return nullptr; // Conflicting source aggregates! + ++NumAggregateReconstructionsSimplified; + return replaceInstUsesWith(OrigIVI, *SourceAggregate); + } + + // Okay, apparently we need to look at predecessors. 
+ + // We should be smart about picking the "use" basic block, which will be the + // merge point for aggregate, where we'll insert the final PHI that will be + // used instead of OrigIVI. Basic block of OrigIVI is *not* the right choice. + // We should look in which blocks each of the AggElts is being defined, + // they all should be defined in the same basic block. + BasicBlock *UseBB = nullptr; + + for (const Optional<Instruction *> &I : AggElts) { + BasicBlock *BB = (*I)->getParent(); + // If it's the first instruction we've encountered, record the basic block. + if (!UseBB) { + UseBB = BB; + continue; + } + // Otherwise, this must be the same basic block we've seen previously. + if (UseBB != BB) + return nullptr; + } + + // If *all* of the elements are basic-block-independent, meaning they are + // either function arguments, or constant expressions, then if we didn't + // handle them without predecessor-aware handling, we won't handle them now. + if (!UseBB) + return nullptr; + + // If we didn't manage to find source aggregate without looking at + // predecessors, and there are no predecessors to look at, then we're done. + if (pred_empty(UseBB)) + return nullptr; + + // Arbitrary predecessor count limit. + static const int PredCountLimit = 64; + + // Cache the (non-uniqified!) list of predecessors in a vector, + // checking the limit at the same time for efficiency. + SmallVector<BasicBlock *, 4> Preds; // May have duplicates! + for (BasicBlock *Pred : predecessors(UseBB)) { + // Don't bother if there are too many predecessors. + if (Preds.size() >= PredCountLimit) // FIXME: only count duplicates once? + return nullptr; + Preds.emplace_back(Pred); + } + + // For each predecessor, what is the source aggregate, + // from which all the elements were originally extracted from? + // Note that we want for the map to have stable iteration order! 
+ SmallDenseMap<BasicBlock *, Value *, 4> SourceAggregates; + for (BasicBlock *Pred : Preds) { + std::pair<decltype(SourceAggregates)::iterator, bool> IV = + SourceAggregates.insert({Pred, nullptr}); + // Did we already evaluate this predecessor? + if (!IV.second) + continue; + + // Let's hope that when coming from predecessor Pred, all elements of the + // aggregate produced by OrigIVI must have been originally extracted from + // the same aggregate. Is that so? Can we find said original aggregate? + SourceAggregate = FindCommonSourceAggregate(UseBB, Pred); + if (Describe(SourceAggregate) != AggregateDescription::Found) + return nullptr; // Give up. + IV.first->second = *SourceAggregate; + } + + // All good! Now we just need to thread the source aggregates here. + // Note that we have to insert the new PHI here, ourselves, because we can't + // rely on InstCombinerImpl::run() inserting it into the right basic block. + // Note that the same block can be a predecessor more than once, + // and we need to preserve that invariant for the PHI node. + BuilderTy::InsertPointGuard Guard(Builder); + Builder.SetInsertPoint(UseBB->getFirstNonPHI()); + auto *PHI = + Builder.CreatePHI(AggTy, Preds.size(), OrigIVI.getName() + ".merged"); + for (BasicBlock *Pred : Preds) + PHI->addIncoming(SourceAggregates[Pred], Pred); + + ++NumAggregateReconstructionsSimplified; + return replaceInstUsesWith(OrigIVI, PHI); +} + +/// Try to find redundant insertvalue instructions, like the following ones: +/// %0 = insertvalue { i8, i32 } undef, i8 %x, 0 +/// %1 = insertvalue { i8, i32 } %0, i8 %y, 0 +/// Here the second instruction inserts values at the same indices, as the +/// first one, making the first one redundant. 
+/// It should be transformed to: +/// %0 = insertvalue { i8, i32 } undef, i8 %y, 0 +Instruction *InstCombinerImpl::visitInsertValueInst(InsertValueInst &I) { + bool IsRedundant = false; + ArrayRef<unsigned int> FirstIndices = I.getIndices(); + + // If there is a chain of insertvalue instructions (each of them except the + // last one has only one use and it's another insertvalue insn from this + // chain), check if any of the 'children' uses the same indices as the first + // instruction. In this case, the first one is redundant. + Value *V = &I; + unsigned Depth = 0; + while (V->hasOneUse() && Depth < 10) { + User *U = V->user_back(); + auto UserInsInst = dyn_cast<InsertValueInst>(U); + if (!UserInsInst || U->getOperand(0) != V) + break; + if (UserInsInst->getIndices() == FirstIndices) { + IsRedundant = true; + break; + } + V = UserInsInst; + Depth++; + } + + if (IsRedundant) + return replaceInstUsesWith(I, I.getOperand(0)); + + if (Instruction *NewI = foldAggregateConstructionIntoAggregateReuse(I)) + return NewI; + + return nullptr; +} + +static bool isShuffleEquivalentToSelect(ShuffleVectorInst &Shuf) { + // Can not analyze scalable type, the number of elements is not a compile-time + // constant. + if (isa<ScalableVectorType>(Shuf.getOperand(0)->getType())) + return false; + + int MaskSize = Shuf.getShuffleMask().size(); + int VecSize = + cast<FixedVectorType>(Shuf.getOperand(0)->getType())->getNumElements(); + + // A vector select does not change the size of the operands. + if (MaskSize != VecSize) + return false; + + // Each mask element must be undefined or choose a vector element from one of + // the source operands without crossing vector lanes. 
+ for (int i = 0; i != MaskSize; ++i) { + int Elt = Shuf.getMaskValue(i); + if (Elt != -1 && Elt != i && Elt != i + VecSize) + return false; + } + + return true; +} + +/// Turn a chain of inserts that splats a value into an insert + shuffle: +/// insertelt(insertelt(insertelt(insertelt X, %k, 0), %k, 1), %k, 2) ... -> +/// shufflevector(insertelt(X, %k, 0), poison, zero) +static Instruction *foldInsSequenceIntoSplat(InsertElementInst &InsElt) { + // We are interested in the last insert in a chain. So if this insert has a + // single user and that user is an insert, bail. + if (InsElt.hasOneUse() && isa<InsertElementInst>(InsElt.user_back())) + return nullptr; + + VectorType *VecTy = InsElt.getType(); + // Can not handle scalable type, the number of elements is not a compile-time + // constant. + if (isa<ScalableVectorType>(VecTy)) + return nullptr; + unsigned NumElements = cast<FixedVectorType>(VecTy)->getNumElements(); + + // Do not try to do this for a one-element vector, since that's a nop, + // and will cause an inf-loop. + if (NumElements == 1) + return nullptr; + + Value *SplatVal = InsElt.getOperand(1); + InsertElementInst *CurrIE = &InsElt; + SmallBitVector ElementPresent(NumElements, false); + InsertElementInst *FirstIE = nullptr; + + // Walk the chain backwards, keeping track of which indices we inserted into, + // until we hit something that isn't an insert of the splatted value. + while (CurrIE) { + auto *Idx = dyn_cast<ConstantInt>(CurrIE->getOperand(2)); + if (!Idx || CurrIE->getOperand(1) != SplatVal) + return nullptr; + + auto *NextIE = dyn_cast<InsertElementInst>(CurrIE->getOperand(0)); + // Check none of the intermediate steps have any additional uses, except + // for the root insertelement instruction, which can be re-used, if it + // inserts at position 0. 
+ if (CurrIE != &InsElt && + (!CurrIE->hasOneUse() && (NextIE != nullptr || !Idx->isZero()))) + return nullptr; + + ElementPresent[Idx->getZExtValue()] = true; + FirstIE = CurrIE; + CurrIE = NextIE; + } + + // If this is just a single insertelement (not a sequence), we are done. + if (FirstIE == &InsElt) + return nullptr; + + // If we are not inserting into an undef vector, make sure we've seen an + // insert into every element. + // TODO: If the base vector is not undef, it might be better to create a splat + // and then a select-shuffle (blend) with the base vector. + if (!match(FirstIE->getOperand(0), m_Undef())) + if (!ElementPresent.all()) + return nullptr; + + // Create the insert + shuffle. + Type *Int32Ty = Type::getInt32Ty(InsElt.getContext()); + PoisonValue *PoisonVec = PoisonValue::get(VecTy); + Constant *Zero = ConstantInt::get(Int32Ty, 0); + if (!cast<ConstantInt>(FirstIE->getOperand(2))->isZero()) + FirstIE = InsertElementInst::Create(PoisonVec, SplatVal, Zero, "", &InsElt); + + // Splat from element 0, but replace absent elements with undef in the mask. + SmallVector<int, 16> Mask(NumElements, 0); + for (unsigned i = 0; i != NumElements; ++i) + if (!ElementPresent[i]) + Mask[i] = -1; + + return new ShuffleVectorInst(FirstIE, Mask); +} + +/// Try to fold an insert element into an existing splat shuffle by changing +/// the shuffle's mask to include the index of this insert element. +static Instruction *foldInsEltIntoSplat(InsertElementInst &InsElt) { + // Check if the vector operand of this insert is a canonical splat shuffle. + auto *Shuf = dyn_cast<ShuffleVectorInst>(InsElt.getOperand(0)); + if (!Shuf || !Shuf->isZeroEltSplat()) + return nullptr; + + // Bail out early if shuffle is scalable type. The number of elements in + // shuffle mask is unknown at compile-time. + if (isa<ScalableVectorType>(Shuf->getType())) + return nullptr; + + // Check for a constant insertion index. 
+ uint64_t IdxC; + if (!match(InsElt.getOperand(2), m_ConstantInt(IdxC))) + return nullptr; + + // Check if the splat shuffle's input is the same as this insert's scalar op. + Value *X = InsElt.getOperand(1); + Value *Op0 = Shuf->getOperand(0); + if (!match(Op0, m_InsertElt(m_Undef(), m_Specific(X), m_ZeroInt()))) + return nullptr; + + // Replace the shuffle mask element at the index of this insert with a zero. + // For example: + // inselt (shuf (inselt undef, X, 0), _, <0,undef,0,undef>), X, 1 + // --> shuf (inselt undef, X, 0), poison, <0,0,0,undef> + unsigned NumMaskElts = + cast<FixedVectorType>(Shuf->getType())->getNumElements(); + SmallVector<int, 16> NewMask(NumMaskElts); + for (unsigned i = 0; i != NumMaskElts; ++i) + NewMask[i] = i == IdxC ? 0 : Shuf->getMaskValue(i); + + return new ShuffleVectorInst(Op0, NewMask); +} + +/// Try to fold an extract+insert element into an existing identity shuffle by +/// changing the shuffle's mask to include the index of this insert element. +static Instruction *foldInsEltIntoIdentityShuffle(InsertElementInst &InsElt) { + // Check if the vector operand of this insert is an identity shuffle. + auto *Shuf = dyn_cast<ShuffleVectorInst>(InsElt.getOperand(0)); + if (!Shuf || !match(Shuf->getOperand(1), m_Undef()) || + !(Shuf->isIdentityWithExtract() || Shuf->isIdentityWithPadding())) + return nullptr; + + // Bail out early if shuffle is scalable type. The number of elements in + // shuffle mask is unknown at compile-time. + if (isa<ScalableVectorType>(Shuf->getType())) + return nullptr; + + // Check for a constant insertion index. + uint64_t IdxC; + if (!match(InsElt.getOperand(2), m_ConstantInt(IdxC))) + return nullptr; + + // Check if this insert's scalar op is extracted from the identity shuffle's + // input vector. 
+ Value *Scalar = InsElt.getOperand(1); + Value *X = Shuf->getOperand(0); + if (!match(Scalar, m_ExtractElt(m_Specific(X), m_SpecificInt(IdxC)))) + return nullptr; + + // Replace the shuffle mask element at the index of this extract+insert with + // that same index value. + // For example: + // inselt (shuf X, IdMask), (extelt X, IdxC), IdxC --> shuf X, IdMask' + unsigned NumMaskElts = + cast<FixedVectorType>(Shuf->getType())->getNumElements(); + SmallVector<int, 16> NewMask(NumMaskElts); + ArrayRef<int> OldMask = Shuf->getShuffleMask(); + for (unsigned i = 0; i != NumMaskElts; ++i) { + if (i != IdxC) { + // All mask elements besides the inserted element remain the same. + NewMask[i] = OldMask[i]; + } else if (OldMask[i] == (int)IdxC) { + // If the mask element was already set, there's nothing to do + // (demanded elements analysis may unset it later). + return nullptr; + } else { + assert(OldMask[i] == UndefMaskElem && + "Unexpected shuffle mask element for identity shuffle"); + NewMask[i] = IdxC; + } + } + + return new ShuffleVectorInst(X, Shuf->getOperand(1), NewMask); +} + +/// If we have an insertelement instruction feeding into another insertelement +/// and the 2nd is inserting a constant into the vector, canonicalize that +/// constant insertion before the insertion of a variable: +/// +/// insertelement (insertelement X, Y, IdxC1), ScalarC, IdxC2 --> +/// insertelement (insertelement X, ScalarC, IdxC2), Y, IdxC1 +/// +/// This has the potential of eliminating the 2nd insertelement instruction +/// via constant folding of the scalar constant into a vector constant. 
+static Instruction *hoistInsEltConst(InsertElementInst &InsElt2, + InstCombiner::BuilderTy &Builder) { + auto *InsElt1 = dyn_cast<InsertElementInst>(InsElt2.getOperand(0)); + if (!InsElt1 || !InsElt1->hasOneUse()) + return nullptr; + + Value *X, *Y; + Constant *ScalarC; + ConstantInt *IdxC1, *IdxC2; + if (match(InsElt1->getOperand(0), m_Value(X)) && + match(InsElt1->getOperand(1), m_Value(Y)) && !isa<Constant>(Y) && + match(InsElt1->getOperand(2), m_ConstantInt(IdxC1)) && + match(InsElt2.getOperand(1), m_Constant(ScalarC)) && + match(InsElt2.getOperand(2), m_ConstantInt(IdxC2)) && IdxC1 != IdxC2) { + Value *NewInsElt1 = Builder.CreateInsertElement(X, ScalarC, IdxC2); + return InsertElementInst::Create(NewInsElt1, Y, IdxC1); + } + + return nullptr; +} + +/// insertelt (shufflevector X, CVec, Mask|insertelt X, C1, CIndex1), C, CIndex +/// --> shufflevector X, CVec', Mask' +static Instruction *foldConstantInsEltIntoShuffle(InsertElementInst &InsElt) { + auto *Inst = dyn_cast<Instruction>(InsElt.getOperand(0)); + // Bail out if the parent has more than one use. In that case, we'd be + // replacing the insertelt with a shuffle, and that's not a clear win. + if (!Inst || !Inst->hasOneUse()) + return nullptr; + if (auto *Shuf = dyn_cast<ShuffleVectorInst>(InsElt.getOperand(0))) { + // The shuffle must have a constant vector operand. The insertelt must have + // a constant scalar being inserted at a constant position in the vector. + Constant *ShufConstVec, *InsEltScalar; + uint64_t InsEltIndex; + if (!match(Shuf->getOperand(1), m_Constant(ShufConstVec)) || + !match(InsElt.getOperand(1), m_Constant(InsEltScalar)) || + !match(InsElt.getOperand(2), m_ConstantInt(InsEltIndex))) + return nullptr; + + // Adding an element to an arbitrary shuffle could be expensive, but a + // shuffle that selects elements from vectors without crossing lanes is + // assumed cheap. + // If we're just adding a constant into that shuffle, it will still be + // cheap. 
+ if (!isShuffleEquivalentToSelect(*Shuf)) + return nullptr; + + // From the above 'select' check, we know that the mask has the same number + // of elements as the vector input operands. We also know that each constant + // input element is used in its lane and can not be used more than once by + // the shuffle. Therefore, replace the constant in the shuffle's constant + // vector with the insertelt constant. Replace the constant in the shuffle's + // mask vector with the insertelt index plus the length of the vector + // (because the constant vector operand of a shuffle is always the 2nd + // operand). + ArrayRef<int> Mask = Shuf->getShuffleMask(); + unsigned NumElts = Mask.size(); + SmallVector<Constant *, 16> NewShufElts(NumElts); + SmallVector<int, 16> NewMaskElts(NumElts); + for (unsigned I = 0; I != NumElts; ++I) { + if (I == InsEltIndex) { + NewShufElts[I] = InsEltScalar; + NewMaskElts[I] = InsEltIndex + NumElts; + } else { + // Copy over the existing values. + NewShufElts[I] = ShufConstVec->getAggregateElement(I); + NewMaskElts[I] = Mask[I]; + } + + // Bail if we failed to find an element. + if (!NewShufElts[I]) + return nullptr; + } + + // Create new operands for a shuffle that includes the constant of the + // original insertelt. The old shuffle will be dead now. + return new ShuffleVectorInst(Shuf->getOperand(0), + ConstantVector::get(NewShufElts), NewMaskElts); + } else if (auto *IEI = dyn_cast<InsertElementInst>(Inst)) { + // Transform sequences of insertelements ops with constant data/indexes into + // a single shuffle op. + // Can not handle scalable type, the number of elements needed to create + // shuffle mask is not a compile-time constant. 
+ if (isa<ScalableVectorType>(InsElt.getType())) + return nullptr; + unsigned NumElts = + cast<FixedVectorType>(InsElt.getType())->getNumElements(); + + uint64_t InsertIdx[2]; + Constant *Val[2]; + if (!match(InsElt.getOperand(2), m_ConstantInt(InsertIdx[0])) || + !match(InsElt.getOperand(1), m_Constant(Val[0])) || + !match(IEI->getOperand(2), m_ConstantInt(InsertIdx[1])) || + !match(IEI->getOperand(1), m_Constant(Val[1]))) + return nullptr; + SmallVector<Constant *, 16> Values(NumElts); + SmallVector<int, 16> Mask(NumElts); + auto ValI = std::begin(Val); + // Generate new constant vector and mask. + // We have 2 values/masks from the insertelements instructions. Insert them + // into new value/mask vectors. + for (uint64_t I : InsertIdx) { + if (!Values[I]) { + Values[I] = *ValI; + Mask[I] = NumElts + I; + } + ++ValI; + } + // Remaining values are filled with 'undef' values. + for (unsigned I = 0; I < NumElts; ++I) { + if (!Values[I]) { + Values[I] = UndefValue::get(InsElt.getType()->getElementType()); + Mask[I] = I; + } + } + // Create new operands for a shuffle that includes the constant of the + // original insertelt. + return new ShuffleVectorInst(IEI->getOperand(0), + ConstantVector::get(Values), Mask); + } + return nullptr; +} + +/// If both the base vector and the inserted element are extended from the same +/// type, do the insert element in the narrow source type followed by extend. +/// TODO: This can be extended to include other cast opcodes, but particularly +/// if we create a wider insertelement, make sure codegen is not harmed. +static Instruction *narrowInsElt(InsertElementInst &InsElt, + InstCombiner::BuilderTy &Builder) { + // We are creating a vector extend. If the original vector extend has another + // use, that would mean we end up with 2 vector extends, so avoid that. + // TODO: We could ease the use-clause to "if at least one op has one use" + // (assuming that the source types match - see next TODO comment). 
+ Value *Vec = InsElt.getOperand(0); + if (!Vec->hasOneUse()) + return nullptr; + + Value *Scalar = InsElt.getOperand(1); + Value *X, *Y; + CastInst::CastOps CastOpcode; + if (match(Vec, m_FPExt(m_Value(X))) && match(Scalar, m_FPExt(m_Value(Y)))) + CastOpcode = Instruction::FPExt; + else if (match(Vec, m_SExt(m_Value(X))) && match(Scalar, m_SExt(m_Value(Y)))) + CastOpcode = Instruction::SExt; + else if (match(Vec, m_ZExt(m_Value(X))) && match(Scalar, m_ZExt(m_Value(Y)))) + CastOpcode = Instruction::ZExt; + else + return nullptr; + + // TODO: We can allow mismatched types by creating an intermediate cast. + if (X->getType()->getScalarType() != Y->getType()) + return nullptr; + + // inselt (ext X), (ext Y), Index --> ext (inselt X, Y, Index) + Value *NewInsElt = Builder.CreateInsertElement(X, Y, InsElt.getOperand(2)); + return CastInst::Create(CastOpcode, NewInsElt, InsElt.getType()); +} + +Instruction *InstCombinerImpl::visitInsertElementInst(InsertElementInst &IE) { + Value *VecOp = IE.getOperand(0); + Value *ScalarOp = IE.getOperand(1); + Value *IdxOp = IE.getOperand(2); + + if (auto *V = simplifyInsertElementInst( + VecOp, ScalarOp, IdxOp, SQ.getWithInstruction(&IE))) + return replaceInstUsesWith(IE, V); + + // Canonicalize type of constant indices to i64 to simplify CSE + if (auto *IndexC = dyn_cast<ConstantInt>(IdxOp)) + if (auto *NewIdx = getPreferredVectorIndex(IndexC)) + return replaceOperand(IE, 2, NewIdx); + + // If the scalar is bitcast and inserted into undef, do the insert in the + // source type followed by bitcast. + // TODO: Generalize for insert into any constant, not just undef? 
+ Value *ScalarSrc; + if (match(VecOp, m_Undef()) && + match(ScalarOp, m_OneUse(m_BitCast(m_Value(ScalarSrc)))) && + (ScalarSrc->getType()->isIntegerTy() || + ScalarSrc->getType()->isFloatingPointTy())) { + // inselt undef, (bitcast ScalarSrc), IdxOp --> + // bitcast (inselt undef, ScalarSrc, IdxOp) + Type *ScalarTy = ScalarSrc->getType(); + Type *VecTy = VectorType::get(ScalarTy, IE.getType()->getElementCount()); + UndefValue *NewUndef = UndefValue::get(VecTy); + Value *NewInsElt = Builder.CreateInsertElement(NewUndef, ScalarSrc, IdxOp); + return new BitCastInst(NewInsElt, IE.getType()); + } + + // If the vector and scalar are both bitcast from the same element type, do + // the insert in that source type followed by bitcast. + Value *VecSrc; + if (match(VecOp, m_BitCast(m_Value(VecSrc))) && + match(ScalarOp, m_BitCast(m_Value(ScalarSrc))) && + (VecOp->hasOneUse() || ScalarOp->hasOneUse()) && + VecSrc->getType()->isVectorTy() && !ScalarSrc->getType()->isVectorTy() && + cast<VectorType>(VecSrc->getType())->getElementType() == + ScalarSrc->getType()) { + // inselt (bitcast VecSrc), (bitcast ScalarSrc), IdxOp --> + // bitcast (inselt VecSrc, ScalarSrc, IdxOp) + Value *NewInsElt = Builder.CreateInsertElement(VecSrc, ScalarSrc, IdxOp); + return new BitCastInst(NewInsElt, IE.getType()); + } + + // If the inserted element was extracted from some other fixed-length vector + // and both indexes are valid constants, try to turn this into a shuffle. + // Can not handle scalable vector type, the number of elements needed to + // create shuffle mask is not a compile-time constant. 
+ uint64_t InsertedIdx, ExtractedIdx; + Value *ExtVecOp; + if (isa<FixedVectorType>(IE.getType()) && + match(IdxOp, m_ConstantInt(InsertedIdx)) && + match(ScalarOp, + m_ExtractElt(m_Value(ExtVecOp), m_ConstantInt(ExtractedIdx))) && + isa<FixedVectorType>(ExtVecOp->getType()) && + ExtractedIdx < + cast<FixedVectorType>(ExtVecOp->getType())->getNumElements()) { + // TODO: Looking at the user(s) to determine if this insert is a + // fold-to-shuffle opportunity does not match the usual instcombine + // constraints. We should decide if the transform is worthy based only + // on this instruction and its operands, but that may not work currently. + // + // Here, we are trying to avoid creating shuffles before reaching + // the end of a chain of extract-insert pairs. This is complicated because + // we do not generally form arbitrary shuffle masks in instcombine + // (because those may codegen poorly), but collectShuffleElements() does + // exactly that. + // + // The rules for determining what is an acceptable target-independent + // shuffle mask are fuzzy because they evolve based on the backend's + // capabilities and real-world impact. + auto isShuffleRootCandidate = [](InsertElementInst &Insert) { + if (!Insert.hasOneUse()) + return true; + auto *InsertUser = dyn_cast<InsertElementInst>(Insert.user_back()); + if (!InsertUser) + return true; + return false; + }; + + // Try to form a shuffle from a chain of extract-insert ops. + if (isShuffleRootCandidate(IE)) { + SmallVector<int, 16> Mask; + ShuffleOps LR = collectShuffleElements(&IE, Mask, nullptr, *this); + + // The proposed shuffle may be trivial, in which case we shouldn't + // perform the combine. + if (LR.first != &IE && LR.second != &IE) { + // We now have a shuffle of LHS, RHS, Mask. 
+ if (LR.second == nullptr) + LR.second = UndefValue::get(LR.first->getType()); + return new ShuffleVectorInst(LR.first, LR.second, Mask); + } + } + } + + if (auto VecTy = dyn_cast<FixedVectorType>(VecOp->getType())) { + unsigned VWidth = VecTy->getNumElements(); + APInt UndefElts(VWidth, 0); + APInt AllOnesEltMask(APInt::getAllOnes(VWidth)); + if (Value *V = SimplifyDemandedVectorElts(&IE, AllOnesEltMask, UndefElts)) { + if (V != &IE) + return replaceInstUsesWith(IE, V); + return &IE; + } + } + + if (Instruction *Shuf = foldConstantInsEltIntoShuffle(IE)) + return Shuf; + + if (Instruction *NewInsElt = hoistInsEltConst(IE, Builder)) + return NewInsElt; + + if (Instruction *Broadcast = foldInsSequenceIntoSplat(IE)) + return Broadcast; + + if (Instruction *Splat = foldInsEltIntoSplat(IE)) + return Splat; + + if (Instruction *IdentityShuf = foldInsEltIntoIdentityShuffle(IE)) + return IdentityShuf; + + if (Instruction *Ext = narrowInsElt(IE, Builder)) + return Ext; + + return nullptr; +} + +/// Return true if we can evaluate the specified expression tree if the vector +/// elements were shuffled in a different order. +static bool canEvaluateShuffled(Value *V, ArrayRef<int> Mask, + unsigned Depth = 5) { + // We can always reorder the elements of a constant. + if (isa<Constant>(V)) + return true; + + // We won't reorder vector arguments. No IPO here. + Instruction *I = dyn_cast<Instruction>(V); + if (!I) return false; + + // Two users may expect different orders of the elements. Don't try it. + if (!I->hasOneUse()) + return false; + + if (Depth == 0) return false; + + switch (I->getOpcode()) { + case Instruction::UDiv: + case Instruction::SDiv: + case Instruction::URem: + case Instruction::SRem: + // Propagating an undefined shuffle mask element to integer div/rem is not + // allowed because those opcodes can create immediate undefined behavior + // from an undefined element in an operand. 
+ if (llvm::is_contained(Mask, -1)) + return false; + LLVM_FALLTHROUGH; + case Instruction::Add: + case Instruction::FAdd: + case Instruction::Sub: + case Instruction::FSub: + case Instruction::Mul: + case Instruction::FMul: + case Instruction::FDiv: + case Instruction::FRem: + case Instruction::Shl: + case Instruction::LShr: + case Instruction::AShr: + case Instruction::And: + case Instruction::Or: + case Instruction::Xor: + case Instruction::ICmp: + case Instruction::FCmp: + case Instruction::Trunc: + case Instruction::ZExt: + case Instruction::SExt: + case Instruction::FPToUI: + case Instruction::FPToSI: + case Instruction::UIToFP: + case Instruction::SIToFP: + case Instruction::FPTrunc: + case Instruction::FPExt: + case Instruction::GetElementPtr: { + // Bail out if we would create longer vector ops. We could allow creating + // longer vector ops, but that may result in more expensive codegen. + Type *ITy = I->getType(); + if (ITy->isVectorTy() && + Mask.size() > cast<FixedVectorType>(ITy)->getNumElements()) + return false; + for (Value *Operand : I->operands()) { + if (!canEvaluateShuffled(Operand, Mask, Depth - 1)) + return false; + } + return true; + } + case Instruction::InsertElement: { + ConstantInt *CI = dyn_cast<ConstantInt>(I->getOperand(2)); + if (!CI) return false; + int ElementNumber = CI->getLimitedValue(); + + // Verify that 'CI' does not occur twice in Mask. A single 'insertelement' + // can't put an element into multiple indices. + bool SeenOnce = false; + for (int i = 0, e = Mask.size(); i != e; ++i) { + if (Mask[i] == ElementNumber) { + if (SeenOnce) + return false; + SeenOnce = true; + } + } + return canEvaluateShuffled(I->getOperand(0), Mask, Depth - 1); + } + } + return false; +} + +/// Rebuild a new instruction just like 'I' but with the new operands given. +/// In the event of type mismatch, the type of the operands is correct. 
+static Value *buildNew(Instruction *I, ArrayRef<Value*> NewOps) { + // We don't want to use the IRBuilder here because we want the replacement + // instructions to appear next to 'I', not the builder's insertion point. + switch (I->getOpcode()) { + case Instruction::Add: + case Instruction::FAdd: + case Instruction::Sub: + case Instruction::FSub: + case Instruction::Mul: + case Instruction::FMul: + case Instruction::UDiv: + case Instruction::SDiv: + case Instruction::FDiv: + case Instruction::URem: + case Instruction::SRem: + case Instruction::FRem: + case Instruction::Shl: + case Instruction::LShr: + case Instruction::AShr: + case Instruction::And: + case Instruction::Or: + case Instruction::Xor: { + BinaryOperator *BO = cast<BinaryOperator>(I); + assert(NewOps.size() == 2 && "binary operator with #ops != 2"); + BinaryOperator *New = + BinaryOperator::Create(cast<BinaryOperator>(I)->getOpcode(), + NewOps[0], NewOps[1], "", BO); + if (isa<OverflowingBinaryOperator>(BO)) { + New->setHasNoUnsignedWrap(BO->hasNoUnsignedWrap()); + New->setHasNoSignedWrap(BO->hasNoSignedWrap()); + } + if (isa<PossiblyExactOperator>(BO)) { + New->setIsExact(BO->isExact()); + } + if (isa<FPMathOperator>(BO)) + New->copyFastMathFlags(I); + return New; + } + case Instruction::ICmp: + assert(NewOps.size() == 2 && "icmp with #ops != 2"); + return new ICmpInst(I, cast<ICmpInst>(I)->getPredicate(), + NewOps[0], NewOps[1]); + case Instruction::FCmp: + assert(NewOps.size() == 2 && "fcmp with #ops != 2"); + return new FCmpInst(I, cast<FCmpInst>(I)->getPredicate(), + NewOps[0], NewOps[1]); + case Instruction::Trunc: + case Instruction::ZExt: + case Instruction::SExt: + case Instruction::FPToUI: + case Instruction::FPToSI: + case Instruction::UIToFP: + case Instruction::SIToFP: + case Instruction::FPTrunc: + case Instruction::FPExt: { + // It's possible that the mask has a different number of elements from + // the original cast. We recompute the destination type to match the mask. 
+ Type *DestTy = VectorType::get( + I->getType()->getScalarType(), + cast<VectorType>(NewOps[0]->getType())->getElementCount()); + assert(NewOps.size() == 1 && "cast with #ops != 1"); + return CastInst::Create(cast<CastInst>(I)->getOpcode(), NewOps[0], DestTy, + "", I); + } + case Instruction::GetElementPtr: { + Value *Ptr = NewOps[0]; + ArrayRef<Value*> Idx = NewOps.slice(1); + GetElementPtrInst *GEP = GetElementPtrInst::Create( + cast<GetElementPtrInst>(I)->getSourceElementType(), Ptr, Idx, "", I); + GEP->setIsInBounds(cast<GetElementPtrInst>(I)->isInBounds()); + return GEP; + } + } + llvm_unreachable("failed to rebuild vector instructions"); +} + +static Value *evaluateInDifferentElementOrder(Value *V, ArrayRef<int> Mask) { + // Mask.size() does not need to be equal to the number of vector elements. + + assert(V->getType()->isVectorTy() && "can't reorder non-vector elements"); + Type *EltTy = V->getType()->getScalarType(); + Type *I32Ty = IntegerType::getInt32Ty(V->getContext()); + if (match(V, m_Undef())) + return UndefValue::get(FixedVectorType::get(EltTy, Mask.size())); + + if (isa<ConstantAggregateZero>(V)) + return ConstantAggregateZero::get(FixedVectorType::get(EltTy, Mask.size())); + + if (Constant *C = dyn_cast<Constant>(V)) + return ConstantExpr::getShuffleVector(C, PoisonValue::get(C->getType()), + Mask); + + Instruction *I = cast<Instruction>(V); + switch (I->getOpcode()) { + case Instruction::Add: + case Instruction::FAdd: + case Instruction::Sub: + case Instruction::FSub: + case Instruction::Mul: + case Instruction::FMul: + case Instruction::UDiv: + case Instruction::SDiv: + case Instruction::FDiv: + case Instruction::URem: + case Instruction::SRem: + case Instruction::FRem: + case Instruction::Shl: + case Instruction::LShr: + case Instruction::AShr: + case Instruction::And: + case Instruction::Or: + case Instruction::Xor: + case Instruction::ICmp: + case Instruction::FCmp: + case Instruction::Trunc: + case Instruction::ZExt: + case 
Instruction::SExt: + case Instruction::FPToUI: + case Instruction::FPToSI: + case Instruction::UIToFP: + case Instruction::SIToFP: + case Instruction::FPTrunc: + case Instruction::FPExt: + case Instruction::Select: + case Instruction::GetElementPtr: { + SmallVector<Value*, 8> NewOps; + bool NeedsRebuild = + (Mask.size() != + cast<FixedVectorType>(I->getType())->getNumElements()); + for (int i = 0, e = I->getNumOperands(); i != e; ++i) { + Value *V; + // Recursively call evaluateInDifferentElementOrder on vector arguments + // as well. E.g. GetElementPtr may have scalar operands even if the + // return value is a vector, so we need to examine the operand type. + if (I->getOperand(i)->getType()->isVectorTy()) + V = evaluateInDifferentElementOrder(I->getOperand(i), Mask); + else + V = I->getOperand(i); + NewOps.push_back(V); + NeedsRebuild |= (V != I->getOperand(i)); + } + if (NeedsRebuild) { + return buildNew(I, NewOps); + } + return I; + } + case Instruction::InsertElement: { + int Element = cast<ConstantInt>(I->getOperand(2))->getLimitedValue(); + + // The insertelement was inserting at Element. Figure out which element + // that becomes after shuffling. The answer is guaranteed to be unique + // by CanEvaluateShuffled. + bool Found = false; + int Index = 0; + for (int e = Mask.size(); Index != e; ++Index) { + if (Mask[Index] == Element) { + Found = true; + break; + } + } + + // If element is not in Mask, no need to handle the operand 1 (element to + // be inserted). Just evaluate values in operand 0 according to Mask. 
+ if (!Found) + return evaluateInDifferentElementOrder(I->getOperand(0), Mask); + + Value *V = evaluateInDifferentElementOrder(I->getOperand(0), Mask); + return InsertElementInst::Create(V, I->getOperand(1), + ConstantInt::get(I32Ty, Index), "", I); + } + } + llvm_unreachable("failed to reorder elements of vector instruction!"); +} + +// Returns true if the shuffle is extracting a contiguous range of values from +// LHS, for example: +// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ +// Input: |AA|BB|CC|DD|EE|FF|GG|HH|II|JJ|KK|LL|MM|NN|OO|PP| +// Shuffles to: |EE|FF|GG|HH| +// +--+--+--+--+ +static bool isShuffleExtractingFromLHS(ShuffleVectorInst &SVI, + ArrayRef<int> Mask) { + unsigned LHSElems = + cast<FixedVectorType>(SVI.getOperand(0)->getType())->getNumElements(); + unsigned MaskElems = Mask.size(); + unsigned BegIdx = Mask.front(); + unsigned EndIdx = Mask.back(); + if (BegIdx > EndIdx || EndIdx >= LHSElems || EndIdx - BegIdx != MaskElems - 1) + return false; + for (unsigned I = 0; I != MaskElems; ++I) + if (static_cast<unsigned>(Mask[I]) != BegIdx + I) + return false; + return true; +} + +/// These are the ingredients in an alternate form binary operator as described +/// below. +struct BinopElts { + BinaryOperator::BinaryOps Opcode; + Value *Op0; + Value *Op1; + BinopElts(BinaryOperator::BinaryOps Opc = (BinaryOperator::BinaryOps)0, + Value *V0 = nullptr, Value *V1 = nullptr) : + Opcode(Opc), Op0(V0), Op1(V1) {} + operator bool() const { return Opcode != 0; } +}; + +/// Binops may be transformed into binops with different opcodes and operands. +/// Reverse the usual canonicalization to enable folds with the non-canonical +/// form of the binop. If a transform is possible, return the elements of the +/// new binop. If not, return invalid elements. 
+static BinopElts getAlternateBinop(BinaryOperator *BO, const DataLayout &DL) { + Value *BO0 = BO->getOperand(0), *BO1 = BO->getOperand(1); + Type *Ty = BO->getType(); + switch (BO->getOpcode()) { + case Instruction::Shl: { + // shl X, C --> mul X, (1 << C) + Constant *C; + if (match(BO1, m_Constant(C))) { + Constant *ShlOne = ConstantExpr::getShl(ConstantInt::get(Ty, 1), C); + return {Instruction::Mul, BO0, ShlOne}; + } + break; + } + case Instruction::Or: { + // or X, C --> add X, C (when X and C have no common bits set) + const APInt *C; + if (match(BO1, m_APInt(C)) && MaskedValueIsZero(BO0, *C, DL)) + return {Instruction::Add, BO0, BO1}; + break; + } + case Instruction::Sub: + // sub 0, X --> mul X, -1 + if (match(BO0, m_ZeroInt())) + return {Instruction::Mul, BO1, ConstantInt::getAllOnesValue(Ty)}; + break; + default: + break; + } + return {}; +} + +static Instruction *foldSelectShuffleWith1Binop(ShuffleVectorInst &Shuf) { + assert(Shuf.isSelect() && "Must have select-equivalent shuffle"); + + // Are we shuffling together some value and that same value after it has been + // modified by a binop with a constant? + Value *Op0 = Shuf.getOperand(0), *Op1 = Shuf.getOperand(1); + Constant *C; + bool Op0IsBinop; + if (match(Op0, m_BinOp(m_Specific(Op1), m_Constant(C)))) + Op0IsBinop = true; + else if (match(Op1, m_BinOp(m_Specific(Op0), m_Constant(C)))) + Op0IsBinop = false; + else + return nullptr; + + // The identity constant for a binop leaves a variable operand unchanged. For + // a vector, this is a splat of something like 0, -1, or 1. + // If there's no identity constant for this binop, we're done. + auto *BO = cast<BinaryOperator>(Op0IsBinop ? Op0 : Op1); + BinaryOperator::BinaryOps BOpcode = BO->getOpcode(); + Constant *IdC = ConstantExpr::getBinOpIdentity(BOpcode, Shuf.getType(), true); + if (!IdC) + return nullptr; + + // Shuffle identity constants into the lanes that return the original value. 
+ // Example: shuf (mul X, {-1,-2,-3,-4}), X, {0,5,6,3} --> mul X, {-1,1,1,-4} + // Example: shuf X, (add X, {-1,-2,-3,-4}), {0,1,6,7} --> add X, {0,0,-3,-4} + // The existing binop constant vector remains in the same operand position. + ArrayRef<int> Mask = Shuf.getShuffleMask(); + Constant *NewC = Op0IsBinop ? ConstantExpr::getShuffleVector(C, IdC, Mask) : + ConstantExpr::getShuffleVector(IdC, C, Mask); + + bool MightCreatePoisonOrUB = + is_contained(Mask, UndefMaskElem) && + (Instruction::isIntDivRem(BOpcode) || Instruction::isShift(BOpcode)); + if (MightCreatePoisonOrUB) + NewC = InstCombiner::getSafeVectorConstantForBinop(BOpcode, NewC, true); + + // shuf (bop X, C), X, M --> bop X, C' + // shuf X, (bop X, C), M --> bop X, C' + Value *X = Op0IsBinop ? Op1 : Op0; + Instruction *NewBO = BinaryOperator::Create(BOpcode, X, NewC); + NewBO->copyIRFlags(BO); + + // An undef shuffle mask element may propagate as an undef constant element in + // the new binop. That would produce poison where the original code might not. + // If we already made a safe constant, then there's no danger. + if (is_contained(Mask, UndefMaskElem) && !MightCreatePoisonOrUB) + NewBO->dropPoisonGeneratingFlags(); + return NewBO; +} + +/// If we have an insert of a scalar to a non-zero element of an undefined +/// vector and then shuffle that value, that's the same as inserting to the zero +/// element and shuffling. Splatting from the zero element is recognized as the +/// canonical form of splat. +static Instruction *canonicalizeInsertSplat(ShuffleVectorInst &Shuf, + InstCombiner::BuilderTy &Builder) { + Value *Op0 = Shuf.getOperand(0), *Op1 = Shuf.getOperand(1); + ArrayRef<int> Mask = Shuf.getShuffleMask(); + Value *X; + uint64_t IndexC; + + // Match a shuffle that is a splat to a non-zero element. 
+ if (!match(Op0, m_OneUse(m_InsertElt(m_Undef(), m_Value(X), + m_ConstantInt(IndexC)))) || + !match(Op1, m_Undef()) || match(Mask, m_ZeroMask()) || IndexC == 0) + return nullptr; + + // Insert into element 0 of an undef vector. + UndefValue *UndefVec = UndefValue::get(Shuf.getType()); + Constant *Zero = Builder.getInt32(0); + Value *NewIns = Builder.CreateInsertElement(UndefVec, X, Zero); + + // Splat from element 0. Any mask element that is undefined remains undefined. + // For example: + // shuf (inselt undef, X, 2), _, <2,2,undef> + // --> shuf (inselt undef, X, 0), poison, <0,0,undef> + unsigned NumMaskElts = + cast<FixedVectorType>(Shuf.getType())->getNumElements(); + SmallVector<int, 16> NewMask(NumMaskElts, 0); + for (unsigned i = 0; i != NumMaskElts; ++i) + if (Mask[i] == UndefMaskElem) + NewMask[i] = Mask[i]; + + return new ShuffleVectorInst(NewIns, NewMask); +} + +/// Try to fold shuffles that are the equivalent of a vector select. +Instruction *InstCombinerImpl::foldSelectShuffle(ShuffleVectorInst &Shuf) { + if (!Shuf.isSelect()) + return nullptr; + + // Canonicalize to choose from operand 0 first unless operand 1 is undefined. + // Commuting undef to operand 0 conflicts with another canonicalization. + unsigned NumElts = cast<FixedVectorType>(Shuf.getType())->getNumElements(); + if (!match(Shuf.getOperand(1), m_Undef()) && + Shuf.getMaskValue(0) >= (int)NumElts) { + // TODO: Can we assert that both operands of a shuffle-select are not undef + // (otherwise, it would have been folded by instsimplify? + Shuf.commute(); + return &Shuf; + } + + if (Instruction *I = foldSelectShuffleWith1Binop(Shuf)) + return I; + + BinaryOperator *B0, *B1; + if (!match(Shuf.getOperand(0), m_BinOp(B0)) || + !match(Shuf.getOperand(1), m_BinOp(B1))) + return nullptr; + + // If one operand is "0 - X", allow that to be viewed as "X * -1" + // (ConstantsAreOp1) by getAlternateBinop below. If the neg is not paired + // with a multiply, we will exit because C0/C1 will not be set. 
+ Value *X, *Y; + Constant *C0 = nullptr, *C1 = nullptr; + bool ConstantsAreOp1; + if (match(B0, m_BinOp(m_Constant(C0), m_Value(X))) && + match(B1, m_BinOp(m_Constant(C1), m_Value(Y)))) + ConstantsAreOp1 = false; + else if (match(B0, m_CombineOr(m_BinOp(m_Value(X), m_Constant(C0)), + m_Neg(m_Value(X)))) && + match(B1, m_CombineOr(m_BinOp(m_Value(Y), m_Constant(C1)), + m_Neg(m_Value(Y))))) + ConstantsAreOp1 = true; + else + return nullptr; + + // We need matching binops to fold the lanes together. + BinaryOperator::BinaryOps Opc0 = B0->getOpcode(); + BinaryOperator::BinaryOps Opc1 = B1->getOpcode(); + bool DropNSW = false; + if (ConstantsAreOp1 && Opc0 != Opc1) { + // TODO: We drop "nsw" if shift is converted into multiply because it may + // not be correct when the shift amount is BitWidth - 1. We could examine + // each vector element to determine if it is safe to keep that flag. + if (Opc0 == Instruction::Shl || Opc1 == Instruction::Shl) + DropNSW = true; + if (BinopElts AltB0 = getAlternateBinop(B0, DL)) { + assert(isa<Constant>(AltB0.Op1) && "Expecting constant with alt binop"); + Opc0 = AltB0.Opcode; + C0 = cast<Constant>(AltB0.Op1); + } else if (BinopElts AltB1 = getAlternateBinop(B1, DL)) { + assert(isa<Constant>(AltB1.Op1) && "Expecting constant with alt binop"); + Opc1 = AltB1.Opcode; + C1 = cast<Constant>(AltB1.Op1); + } + } + + if (Opc0 != Opc1 || !C0 || !C1) + return nullptr; + + // The opcodes must be the same. Use a new name to make that clear. + BinaryOperator::BinaryOps BOpc = Opc0; + + // Select the constant elements needed for the single binop. + ArrayRef<int> Mask = Shuf.getShuffleMask(); + Constant *NewC = ConstantExpr::getShuffleVector(C0, C1, Mask); + + // We are moving a binop after a shuffle. When a shuffle has an undefined + // mask element, the result is undefined, but it is not poison or undefined + // behavior. That is not necessarily true for div/rem/shift. 
+ bool MightCreatePoisonOrUB = + is_contained(Mask, UndefMaskElem) && + (Instruction::isIntDivRem(BOpc) || Instruction::isShift(BOpc)); + if (MightCreatePoisonOrUB) + NewC = InstCombiner::getSafeVectorConstantForBinop(BOpc, NewC, + ConstantsAreOp1); + + Value *V; + if (X == Y) { + // Remove a binop and the shuffle by rearranging the constant: + // shuffle (op V, C0), (op V, C1), M --> op V, C' + // shuffle (op C0, V), (op C1, V), M --> op C', V + V = X; + } else { + // If there are 2 different variable operands, we must create a new shuffle + // (select) first, so check uses to ensure that we don't end up with more + // instructions than we started with. + if (!B0->hasOneUse() && !B1->hasOneUse()) + return nullptr; + + // If we use the original shuffle mask and op1 is *variable*, we would be + // putting an undef into operand 1 of div/rem/shift. This is either UB or + // poison. We do not have to guard against UB when *constants* are op1 + // because safe constants guarantee that we do not overflow sdiv/srem (and + // there's no danger for other opcodes). + // TODO: To allow this case, create a new shuffle mask with no undefs. + if (MightCreatePoisonOrUB && !ConstantsAreOp1) + return nullptr; + + // Note: In general, we do not create new shuffles in InstCombine because we + // do not know if a target can lower an arbitrary shuffle optimally. In this + // case, the shuffle uses the existing mask, so there is no additional risk. + + // Select the variable vectors first, then perform the binop: + // shuffle (op X, C0), (op Y, C1), M --> op (shuffle X, Y, M), C' + // shuffle (op C0, X), (op C1, Y), M --> op C', (shuffle X, Y, M) + V = Builder.CreateShuffleVector(X, Y, Mask); + } + + Value *NewBO = ConstantsAreOp1 ? Builder.CreateBinOp(BOpc, V, NewC) : + Builder.CreateBinOp(BOpc, NewC, V); + + // Flags are intersected from the 2 source binops. But there are 2 exceptions: + // 1. If we changed an opcode, poison conditions might have changed. + // 2. 
If the shuffle had undef mask elements, the new binop might have undefs + // where the original code did not. But if we already made a safe constant, + // then there's no danger. + if (auto *NewI = dyn_cast<Instruction>(NewBO)) { + NewI->copyIRFlags(B0); + NewI->andIRFlags(B1); + if (DropNSW) + NewI->setHasNoSignedWrap(false); + if (is_contained(Mask, UndefMaskElem) && !MightCreatePoisonOrUB) + NewI->dropPoisonGeneratingFlags(); + } + return replaceInstUsesWith(Shuf, NewBO); +} + +/// Convert a narrowing shuffle of a bitcasted vector into a vector truncate. +/// Example (little endian): +/// shuf (bitcast <4 x i16> X to <8 x i8>), <0, 2, 4, 6> --> trunc X to <4 x i8> +static Instruction *foldTruncShuffle(ShuffleVectorInst &Shuf, + bool IsBigEndian) { + // This must be a bitcasted shuffle of 1 vector integer operand. + Type *DestType = Shuf.getType(); + Value *X; + if (!match(Shuf.getOperand(0), m_BitCast(m_Value(X))) || + !match(Shuf.getOperand(1), m_Undef()) || !DestType->isIntOrIntVectorTy()) + return nullptr; + + // The source type must have the same number of elements as the shuffle, + // and the source element type must be larger than the shuffle element type. + Type *SrcType = X->getType(); + if (!SrcType->isVectorTy() || !SrcType->isIntOrIntVectorTy() || + cast<FixedVectorType>(SrcType)->getNumElements() != + cast<FixedVectorType>(DestType)->getNumElements() || + SrcType->getScalarSizeInBits() % DestType->getScalarSizeInBits() != 0) + return nullptr; + + assert(Shuf.changesLength() && !Shuf.increasesLength() && + "Expected a shuffle that decreases length"); + + // Last, check that the mask chooses the correct low bits for each narrow + // element in the result. + uint64_t TruncRatio = + SrcType->getScalarSizeInBits() / DestType->getScalarSizeInBits(); + ArrayRef<int> Mask = Shuf.getShuffleMask(); + for (unsigned i = 0, e = Mask.size(); i != e; ++i) { + if (Mask[i] == UndefMaskElem) + continue; + uint64_t LSBIndex = IsBigEndian ? 
(i + 1) * TruncRatio - 1 : i * TruncRatio; + assert(LSBIndex <= INT32_MAX && "Overflowed 32-bits"); + if (Mask[i] != (int)LSBIndex) + return nullptr; + } + + return new TruncInst(X, DestType); +} + +/// Match a shuffle-select-shuffle pattern where the shuffles are widening and +/// narrowing (concatenating with undef and extracting back to the original +/// length). This allows replacing the wide select with a narrow select. +static Instruction *narrowVectorSelect(ShuffleVectorInst &Shuf, + InstCombiner::BuilderTy &Builder) { + // This must be a narrowing identity shuffle. It extracts the 1st N elements + // of the 1st vector operand of a shuffle. + if (!match(Shuf.getOperand(1), m_Undef()) || !Shuf.isIdentityWithExtract()) + return nullptr; + + // The vector being shuffled must be a vector select that we can eliminate. + // TODO: The one-use requirement could be eased if X and/or Y are constants. + Value *Cond, *X, *Y; + if (!match(Shuf.getOperand(0), + m_OneUse(m_Select(m_Value(Cond), m_Value(X), m_Value(Y))))) + return nullptr; + + // We need a narrow condition value. It must be extended with undef elements + // and have the same number of elements as this shuffle. + unsigned NarrowNumElts = + cast<FixedVectorType>(Shuf.getType())->getNumElements(); + Value *NarrowCond; + if (!match(Cond, m_OneUse(m_Shuffle(m_Value(NarrowCond), m_Undef()))) || + cast<FixedVectorType>(NarrowCond->getType())->getNumElements() != + NarrowNumElts || + !cast<ShuffleVectorInst>(Cond)->isIdentityWithPadding()) + return nullptr; + + // shuf (sel (shuf NarrowCond, undef, WideMask), X, Y), undef, NarrowMask) --> + // sel NarrowCond, (shuf X, undef, NarrowMask), (shuf Y, undef, NarrowMask) + Value *NarrowX = Builder.CreateShuffleVector(X, Shuf.getShuffleMask()); + Value *NarrowY = Builder.CreateShuffleVector(Y, Shuf.getShuffleMask()); + return SelectInst::Create(NarrowCond, NarrowX, NarrowY); +} + +/// Canonicalize FP negate after shuffle. 
static Instruction *foldFNegShuffle(ShuffleVectorInst &Shuf,
                                    InstCombiner::BuilderTy &Builder) {
  Value *Op0 = Shuf.getOperand(0);
  Value *Op1 = Shuf.getOperand(1);

  // Operand 0 has to be an fneg instruction; remember both the instruction
  // (for its use count and flags) and the value it negates.
  Instruction *NegLHS;
  Value *X;
  if (!match(Op0, m_CombineAnd(m_Instruction(NegLHS), m_FNeg(m_Value(X)))))
    return nullptr;

  // Unary form:
  //   shuffle (fneg X), Mask --> fneg (shuffle X, Mask)
  if (NegLHS->hasOneUse() && match(Op1, m_Undef())) {
    Value *Shuffled = Builder.CreateShuffleVector(X, Shuf.getShuffleMask());
    return UnaryOperator::CreateFNegFMF(Shuffled, NegLHS);
  }

  // Binary form: operand 1 has to be an fneg instruction as well.
  Instruction *NegRHS;
  Value *Y;
  if (!match(Op1, m_CombineAnd(m_Instruction(NegRHS), m_FNeg(m_Value(Y)))))
    return nullptr;

  // Give up unless at least one of the two negates has a single use.
  if (!NegLHS->hasOneUse() && !NegRHS->hasOneUse())
    return nullptr;

  //   shuffle (fneg X), (fneg Y), Mask --> fneg (shuffle X, Y, Mask)
  // The new negate starts from the flags of the first fneg and is then
  // intersected with the flags of the second.
  Value *Shuffled = Builder.CreateShuffleVector(X, Y, Shuf.getShuffleMask());
  Instruction *Result = UnaryOperator::CreateFNeg(Shuffled);
  Result->copyIRFlags(NegLHS);
  Result->andIRFlags(NegRHS);
  return Result;
}

/// Canonicalize casts after shuffle.
static Instruction *foldCastShuffle(ShuffleVectorInst &Shuf,
                                    InstCombiner::BuilderTy &Builder) {
  // Both shuffle operands must be casts with the same opcode and the same
  // source type.
  auto *CastLHS = dyn_cast<CastInst>(Shuf.getOperand(0));
  auto *CastRHS = dyn_cast<CastInst>(Shuf.getOperand(1));
  if (!CastLHS || !CastRHS)
    return nullptr;
  if (CastLHS->getOpcode() != CastRHS->getOpcode() ||
      CastLHS->getSrcTy() != CastRHS->getSrcTy())
    return nullptr;

  // Only the four int<->FP conversions are handled.
  // TODO: Allow other opcodes? That would require easing the type restrictions
  //       below here.
  const CastInst::CastOps CastOpcode = CastLHS->getOpcode();
  const bool IsIntFPConversion =
      CastOpcode == Instruction::FPToSI || CastOpcode == Instruction::FPToUI ||
      CastOpcode == Instruction::SIToFP || CastOpcode == Instruction::UIToFP;
  if (!IsIntFPConversion)
    return nullptr;

  VectorType *ShufTy = Shuf.getType();
  VectorType *ShufOpTy = cast<VectorType>(Shuf.getOperand(0)->getType());
  VectorType *CastSrcTy = cast<VectorType>(CastLHS->getSrcTy());

  // The shuffle result must not have more elements than its operands.
  // TODO: Allow length-increasing shuffles?
  if (ShufTy->getElementCount().getKnownMinValue() >
      ShufOpTy->getElementCount().getKnownMinValue())
    return nullptr;

  // The cast source must not be wider (in total bits) than the cast result.
  // TODO: Allow element-size-decreasing casts (ex: fptosi float to i8)?
  assert(isa<FixedVectorType>(CastSrcTy) && isa<FixedVectorType>(ShufOpTy) &&
         "Expected fixed vector operands for casts and binary shuffle");
  if (CastSrcTy->getPrimitiveSizeInBits() > ShufOpTy->getPrimitiveSizeInBits())
    return nullptr;

  // At least one of the casts must be used only by this shuffle.
  if (!CastLHS->hasOneUse() && !CastRHS->hasOneUse())
    return nullptr;

  //   shuffle (cast X), (cast Y), Mask --> cast (shuffle X, Y, Mask)
  Value *SrcLHS = CastLHS->getOperand(0);
  Value *SrcRHS = CastRHS->getOperand(0);
  Value *Shuffled =
      Builder.CreateShuffleVector(SrcLHS, SrcRHS, Shuf.getShuffleMask());
  return CastInst::Create(CastOpcode, Shuffled, ShufTy);
}

/// Try to fold an extract subvector operation.
static Instruction *foldIdentityExtractShuffle(ShuffleVectorInst &Shuf) {
  Value *Op0 = Shuf.getOperand(0);
  Value *Op1 = Shuf.getOperand(1);
  if (!Shuf.isIdentityWithExtract() || !match(Op1, m_Undef()))
    return nullptr;

  // If the extract covers every bit of a scalar that was inserted at index 0,
  // the whole sequence is just a bitcast of that scalar:
  //   extract-subvec (bitcast (inselt ?, X, 0)) --> bitcast X to subvec type
  Value *X;
  if (match(Op0, m_BitCast(m_InsertElt(m_Value(), m_Value(X), m_Zero()))) &&
      X->getType()->getPrimitiveSizeInBits() ==
          Shuf.getType()->getPrimitiveSizeInBits())
    return new BitCastInst(X, Shuf.getType());

  // Otherwise try to merge this extract into the shuffle that feeds it.
  Value *Y;
  ArrayRef<int> Mask;
  if (!match(Op0, m_Shuffle(m_Value(X), m_Value(Y), m_Mask(Mask))))
    return nullptr;

  // Be conservative: if the feeding shuffle has other users it stays alive,
  // and merging may then result in worse codegen.
  if (!Op0->hasOneUse())
    return nullptr;

  // The extract is eliminated by truncating the first shuffle's mask to the
  // extracted length.
  //
  // This is limited to identity extracts on purpose: arbitrary shuffle masks
  // are not created as a target-independent transform because there is no
  // guarantee that they lower efficiently.
  //
  // An undef element in the extracting mask carries over to the new mask;
  // every other lane copies the element of the original mask. Example:
  //   shuf (shuf X, Y, <C0, C1, C2, undef, C4>), undef, <0, undef, 2, 3> -->
  //   shuf X, Y, <C0, undef, C2, undef>
  unsigned NumElts = cast<FixedVectorType>(Shuf.getType())->getNumElements();
  SmallVector<int, 16> NewMask(NumElts);
  assert(NumElts < Mask.size() &&
         "Identity with extract must have less elements than its inputs");

  for (unsigned Lane = 0; Lane != NumElts; ++Lane) {
    if (Shuf.getMaskValue(Lane) == UndefMaskElem)
      NewMask[Lane] = UndefMaskElem;
    else
      NewMask[Lane] = Mask[Lane];
  }
  return new ShuffleVectorInst(X, Y, NewMask);
}

/// Try to replace a shuffle with an insertelement or try to replace a shuffle
/// operand with the operand of an insertelement.
static Instruction *foldShuffleWithInsert(ShuffleVectorInst &Shuf,
                                          InstCombinerImpl &IC) {
  Value *V0 = Shuf.getOperand(0);
  Value *V1 = Shuf.getOperand(1);
  SmallVector<int, 16> Mask;
  Shuf.getShuffleMask(Mask);

  int NumElts = Mask.size();
  int InpNumElts = cast<FixedVectorType>(V0->getType())->getNumElements();

  // This specializes a fold from SimplifyDemandedVectorElts, which may not be
  // able to handle the case where the insertelement has more than one use.
  // When a shuffle operand is an insertelement whose inserted lane is never
  // selected by the mask, the operand can be replaced by the insertelement's
  // source vector.
  Value *X;
  uint64_t IdxC;
  //   shuf (inselt X, ?, IdxC), ?, Mask --> shuf X, ?, Mask
  if (match(V0, m_InsertElt(m_Value(X), m_Value(), m_ConstantInt(IdxC))) &&
      !is_contained(Mask, (int)IdxC))
    return IC.replaceOperand(Shuf, 0, X);

  if (match(V1, m_InsertElt(m_Value(X), m_Value(), m_ConstantInt(IdxC)))) {
    // Lanes of the second shuffle input are numbered after those of the
    // first, so offset the index by the input vector width.
    IdxC += InpNumElts;
    //   shuf ?, (inselt X, ?, IdxC), Mask --> shuf ?, X, Mask
    if (!is_contained(Mask, (int)IdxC))
      return IC.replaceOperand(Shuf, 1, X);
  }

  // The remaining folds require a shuffle that does not change vector sizes.
  // TODO: This restriction could be removed if the insert has only one use
  //       (because the transform would require a new length-changing shuffle).
  if (NumElts != InpNumElts)
    return nullptr;

  //   shuffle (insert ?, Scalar, IndexC), V1, Mask --> insert V1, Scalar, IndexC'
  // Captures V0, Mask and NumElts by reference so that it can be re-run after
  // the operands and mask are commuted below.
  auto splicesScalarIntoOp1 = [&](Value *&Scalar, ConstantInt *&IndexC) {
    // Operand 0 must be an insertelement with a constant index.
    if (!match(V0, m_InsertElt(m_Value(), m_Value(Scalar),
                               m_ConstantInt(IndexC))))
      return false;

    // Walk the mask to see whether it splices the inserted scalar into the
    // operand 1 vector of the shuffle.
    int NewInsIndex = -1;
    for (int i = 0; i != NumElts; ++i) {
      // Undef lanes are ignored, and so are lanes that take the element of
      // operand 1 from the same position.
      if (Mask[i] == -1 || Mask[i] == NumElts + i)
        continue;

      // Any other lane must select the inserted scalar, and only one lane
      // may do so.
      if (NewInsIndex != -1 || Mask[i] != IndexC->getSExtValue())
        return false;

      // The inserted scalar ends up in lane i of the result.
      NewInsIndex = i;
    }

    assert(NewInsIndex != -1 && "Did not fold shuffle with unused operand?");

    // Report the (possibly translated) lane that the scalar is inserted into.
    IndexC = ConstantInt::get(IndexC->getType(), NewInsIndex);
    return true;
  };

  // When the shuffle is unnecessary, insert the scalar straight into
  // operand 1 of the shuffle. Example:
  //   shuffle (insert ?, S, 1), V1, <1, 5, 6, 7> --> insert V1, S, 0
  Value *Scalar;
  ConstantInt *IndexC;
  if (splicesScalarIntoOp1(Scalar, IndexC))
    return InsertElementInst::Create(V1, Scalar, IndexC);

  // Retry with the shuffle commuted. Example:
  //   shuffle V0, (insert ?, S, 0), <0, 1, 2, 4> -->
  //   shuffle (insert ?, S, 0), V0, <4, 5, 6, 0> --> insert V0, S, 3
  std::swap(V0, V1);
  ShuffleVectorInst::commuteShuffleMask(Mask, NumElts);
  if (splicesScalarIntoOp1(Scalar, IndexC))
    return InsertElementInst::Create(V1, Scalar, IndexC);

  return nullptr;
}

static Instruction *foldIdentityPaddedShuffles(ShuffleVectorInst &Shuf) {
  // Both operands must be identity-with-padding shuffles (also known as
  // concatenation with undef). The backend is expected to recreate these
  // concatenations from a shuffle of narrow operands.
  auto *Shuffle0 = dyn_cast<ShuffleVectorInst>(Shuf.getOperand(0));
  auto *Shuffle1 = dyn_cast<ShuffleVectorInst>(Shuf.getOperand(1));
  if (!Shuffle0 || !Shuffle0->isIdentityWithPadding())
    return nullptr;
  if (!Shuffle1 || !Shuffle1->isIdentityWithPadding())
    return nullptr;

  // The narrow sources must have the same type, and the transform is limited
  // to power-of-2 element counts because the backend is expected to convert
  // the simplified IR patterns to the same nodes as the original IR.
  // TODO: If we can verify the same behavior for arbitrary types, the
  //       power-of-2 checks can be removed.
  Value *X = Shuffle0->getOperand(0);
  Value *Y = Shuffle1->getOperand(0);
  if (X->getType() != Y->getType())
    return nullptr;
  if (!isPowerOf2_32(cast<FixedVectorType>(Shuf.getType())->getNumElements()))
    return nullptr;
  if (!isPowerOf2_32(
          cast<FixedVectorType>(Shuffle0->getType())->getNumElements()))
    return nullptr;
  if (!isPowerOf2_32(cast<FixedVectorType>(X->getType())->getNumElements()))
    return nullptr;
  if (match(X, m_Undef()) || match(Y, m_Undef()))
    return nullptr;
  assert(match(Shuffle0->getOperand(1), m_Undef()) &&
         match(Shuffle1->getOperand(1), m_Undef()) &&
         "Unexpected operand for identity shuffle");

  // This is a shuffle of two widening shuffles. Shuffle the narrow sources
  // directly, with a mask adjusted for the narrower types:
  //   shuf (widen X), (widen Y), Mask --> shuf X, Y, Mask'
  int NarrowElts = cast<FixedVectorType>(X->getType())->getNumElements();
  int WideElts = cast<FixedVectorType>(Shuffle0->getType())->getNumElements();
  assert(WideElts > NarrowElts && "Unexpected types for identity with padding");

  ArrayRef<int> Mask = Shuf.getShuffleMask();
  SmallVector<int, 16> NewMask(Mask.size(), -1);
  for (int i = 0, e = Mask.size(); i != e; ++i) {
    // Undef lanes keep the -1 they were initialized with.
    if (Mask[i] == -1)
      continue;

    if (Mask[i] < WideElts) {
      // Lane comes from the first widened operand. If it picks one of the
      // padding (undef) elements, the result lane stays undef.
      if (Shuffle0->getMaskValue(Mask[i]) == -1)
        continue;
      // Elements of the first narrow operand keep their index.
      assert(Mask[i] < NarrowElts && "Unexpected shuffle mask");
      NewMask[i] = Mask[i];
    } else {
      // Lane comes from the second widened operand; same undef handling.
      if (Shuffle1->getMaskValue(Mask[i] - WideElts) == -1)
        continue;
      // Elements of the second narrow operand are offset down by the
      // difference between the wide and narrow vector widths.
      assert(Mask[i] < (WideElts + NarrowElts) && "Unexpected shuffle mask");
      NewMask[i] = Mask[i] - (WideElts - NarrowElts);
    }
  }
  return new ShuffleVectorInst(X, Y, NewMask);
}

Instruction *InstCombinerImpl::visitShuffleVectorInst(ShuffleVectorInst &SVI) {
  Value *LHS = SVI.getOperand(0);
  Value *RHS = SVI.getOperand(1);
  SimplifyQuery ShufQuery = SQ.getWithInstruction(&SVI);
  if (auto *V = simplifyShuffleVectorInst(LHS, RHS, SVI.getShuffleMask(),
                                          SVI.getType(), ShufQuery))
    return replaceInstUsesWith(SVI, V);

  // Nothing below handles scalable vectors.
  if (isa<ScalableVectorType>(LHS->getType()))
    return nullptr;

  unsigned VWidth = cast<FixedVectorType>(SVI.getType())->getNumElements();
  unsigned LHSWidth = cast<FixedVectorType>(LHS->getType())->getNumElements();

  // shuffle (bitcast X), (bitcast Y), Mask --> bitcast (shuffle X, Y, Mask)
  //
  // When X and Y have the same vector type and the bitcasts do not change the
  // element size, the bitcasts can be distributed through the shuffle,
  // hopefully reducing the number of instructions. At least one bitcast must
  // have a single use so that the instruction count does not *increase*.
  Value *X, *Y;
  if (match(LHS, m_BitCast(m_Value(X))) && match(RHS, m_BitCast(m_Value(Y))) &&
      X->getType()->isVectorTy() && X->getType() == Y->getType() &&
      X->getType()->getScalarSizeInBits() ==
          SVI.getType()->getScalarSizeInBits() &&
      (LHS->hasOneUse() || RHS->hasOneUse())) {
    Value *V = Builder.CreateShuffleVector(X, Y, SVI.getShuffleMask(),
                                           SVI.getName() + ".uncasted");
    return new BitCastInst(V, SVI.getType());
  }

  ArrayRef<int> Mask = SVI.getShuffleMask();
  Type *Int32Ty = Type::getInt32Ty(SVI.getContext());

  // Look through a bitcasted shuffle operand by scaling the mask. If the
  // simulated shuffle simplifies, this shuffle is unnecessary:
  //   shuf (bitcast X), undef, Mask --> bitcast X'
  // TODO: This could be extended to allow length-changing shuffles.
  //       The transform might also be obsoleted if we allowed canonicalization
  //       of bitcasted shuffles.
  if (match(LHS, m_BitCast(m_Value(X))) && match(RHS, m_Undef()) &&
      X->getType()->isVectorTy() && VWidth == LHSWidth) {
    // Try to build a mask expressed in elements of X's type.
    auto *XType = cast<FixedVectorType>(X->getType());
    unsigned XNumElts = XType->getNumElements();
    SmallVector<int, 16> ScaledMask;
    if (XNumElts >= VWidth) {
      assert(XNumElts % VWidth == 0 && "Unexpected vector bitcast");
      narrowShuffleMaskElts(XNumElts / VWidth, Mask, ScaledMask);
    } else {
      assert(VWidth % XNumElts == 0 && "Unexpected vector bitcast");
      // Widening can fail; an empty mask signals "no scaled mask available".
      if (!widenShuffleMaskElts(VWidth / XNumElts, Mask, ScaledMask))
        ScaledMask.clear();
    }
    if (!ScaledMask.empty()) {
      // If the shuffle of the source vector simplifies, cast the simplified
      // value to this shuffle's type.
      if (auto *V = simplifyShuffleVectorInst(X, UndefValue::get(XType),
                                              ScaledMask, XType, ShufQuery))
        return BitCastInst::Create(Instruction::BitCast, V, SVI.getType());
    }
  }

  // shuffle x, x, mask --> shuffle x, undef, mask'
  if (LHS == RHS) {
    assert(!match(RHS, m_Undef()) &&
           "Shuffle with 2 undef ops not simplified?");
    return new ShuffleVectorInst(LHS, createUnaryMask(Mask, LHSWidth));
  }

  // shuffle undef, x, mask --> shuffle x, undef, mask'
  if (match(LHS, m_Undef())) {
    SVI.commute();
    return &SVI;
  }

  // Individual folds; the first one that fires wins.
  if (Instruction *I = canonicalizeInsertSplat(SVI, Builder))
    return I;

  if (Instruction *I = foldSelectShuffle(SVI))
    return I;

  if (Instruction *I = foldTruncShuffle(SVI, DL.isBigEndian()))
    return I;

  if (Instruction *I = narrowVectorSelect(SVI, Builder))
    return I;

  if (Instruction *I = foldFNegShuffle(SVI, Builder))
    return I;

  if (Instruction *I = foldCastShuffle(SVI, Builder))
    return I;

  APInt UndefElts(VWidth, 0);
  APInt AllOnesEltMask(APInt::getAllOnes(VWidth));
  if (Value *V = SimplifyDemandedVectorElts(&SVI, AllOnesEltMask, UndefElts)) {
    if (V != &SVI)
      return replaceInstUsesWith(SVI, V);
    return &SVI;
  }

  if (Instruction *I = foldIdentityExtractShuffle(SVI))
    return I;

  // These transforms have the potential to lose undef knowledge, so they are
  // intentionally placed after SimplifyDemandedVectorElts().
  if (Instruction *I = foldShuffleWithInsert(SVI, *this))
    return I;
  if (Instruction *I = foldIdentityPaddedShuffles(SVI))
    return I;

  if (match(RHS, m_Undef()) && canEvaluateShuffled(LHS, Mask)) {
    Value *V = evaluateInDifferentElementOrder(LHS, Mask);
    return replaceInstUsesWith(SVI, V);
  }

  // SROA generates shuffle+bitcast when the extracted sub-vector is bitcast to
  // a non-vector type. The original vector can be bitcast instead, followed by
  // an extract of the desired element:
  //
  //   %sroa = shufflevector <16 x i8> %in, <16 x i8> undef,
  //                         <4 x i32> <i32 0, i32 1, i32 2, i32 3>
  //   %1 = bitcast <4 x i8> %sroa to i32
  // Becomes:
  //   %bc = bitcast <16 x i8> %in to <4 x i32>
  //   %ext = extractelement <4 x i32> %bc, i32 0
  //
  // If the shuffle extracts a contiguous range of values from the input
  // vector, each use that is a bitcast of the extracted size can be replaced.
  // This works when the vector types are compatible and the begin index is
  // aligned to a value in the casted vector type. If the begin index is not
  // aligned, the original vector is shuffled first (keeping the same vector
  // type) before extracting.
  //
  // The code bails out if the target type is fundamentally incompatible with
  // vectors of the source type.
  //
  // Example of <16 x i8>, target type i32:
  // Index range [4,8):         v-----------v Will work.
  //                +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
  //     <16 x i8>: |  |  |  |  |  |  |  |  |  |  |  |  |  |  |  |  |
  //     <4 x i32>: |           |           |           |           |
  //                +-----------+-----------+-----------+-----------+
  // Index range [6,10):              ^-----------^ Needs an extra shuffle.
  // Target type i40:           ^--------------^ Won't work, bail.
  bool MadeChange = false;
  if (isShuffleExtractingFromLHS(SVI, Mask)) {
    Value *V = LHS;
    unsigned MaskElems = Mask.size();
    auto *SrcTy = cast<FixedVectorType>(V->getType());
    unsigned VecBitWidth = SrcTy->getPrimitiveSizeInBits().getFixedSize();
    unsigned SrcElemBitWidth = DL.getTypeSizeInBits(SrcTy->getElementType());
    assert(SrcElemBitWidth && "vector elements must have a bitwidth");
    unsigned SrcNumElems = SrcTy->getNumElements();
    SmallVector<BitCastInst *, 8> BCs;
    DenseMap<Type *, Value *> NewBCs;

    // Collect the bitcast users up front, since uses are rewritten below.
    // Bitcasts without uses were handled previously and are not visited again.
    for (User *U : SVI.users()) {
      auto *BC = dyn_cast<BitCastInst>(U);
      if (BC && !BC->use_empty())
        BCs.push_back(BC);
    }

    for (BitCastInst *BC : BCs) {
      unsigned BegIdx = Mask.front();
      Type *TgtTy = BC->getDestTy();
      unsigned TgtElemBitWidth = DL.getTypeSizeInBits(TgtTy);
      if (!TgtElemBitWidth)
        continue;
      unsigned TgtNumElems = VecBitWidth / TgtElemBitWidth;
      bool VecBitWidthsEqual = VecBitWidth == TgtNumElems * TgtElemBitWidth;
      bool BegIsAligned = 0 == ((SrcElemBitWidth * BegIdx) % TgtElemBitWidth);
      // The source vector must divide evenly into target-typed elements, and
      // the target type must be usable as a vector element type.
      if (!VecBitWidthsEqual || !VectorType::isValidElementType(TgtTy))
        continue;
      auto *CastSrcTy = FixedVectorType::get(TgtTy, TgtNumElems);
      if (!BegIsAligned) {
        // Shuffle the input so [0,NumElements) contains the output, and
        // [NumElems,SrcNumElems) is undef.
        SmallVector<int, 16> ShuffleMask(SrcNumElems, -1);
        for (unsigned I = 0; I != MaskElems; ++I)
          ShuffleMask[I] = BegIdx + I;
        V = Builder.CreateShuffleVector(V, ShuffleMask,
                                        SVI.getName() + ".extract");
        BegIdx = 0;
      }
      unsigned SrcElemsPerTgtElem = TgtElemBitWidth / SrcElemBitWidth;
      assert(SrcElemsPerTgtElem);
      BegIdx /= SrcElemsPerTgtElem;

      // Create at most one bitcast per destination vector type.
      Value *NewBC;
      auto CachedBC = NewBCs.find(CastSrcTy);
      if (CachedBC != NewBCs.end()) {
        NewBC = CachedBC->second;
      } else {
        NewBC = Builder.CreateBitCast(V, CastSrcTy, SVI.getName() + ".bc");
        NewBCs[CastSrcTy] = NewBC;
      }
      auto *Ext = Builder.CreateExtractElement(
          NewBC, ConstantInt::get(Int32Ty, BegIdx), SVI.getName() + ".extract");
      // The shufflevector isn't being replaced: the bitcast that used it
      // is. InstCombine will visit the newly-created instructions.
      replaceInstUsesWith(*BC, Ext);
      MadeChange = true;
    }
  }

  // If an operand is itself a shufflevector, see whether the two shuffles can
  // be combined without producing an unusual shuffle.
  // Cases that might be simplified:
  // 1.
  // x1=shuffle(v1,v2,mask1)
  //  x=shuffle(x1,undef,mask)
  //        ==>
  //  x=shuffle(v1,undef,newMask)
  // newMask[i] = (mask[i] < x1.size()) ? mask1[mask[i]] : -1
  // 2.
  // x1=shuffle(v1,undef,mask1)
  //  x=shuffle(x1,x2,mask)
  // where v1.size() == mask1.size()
  //        ==>
  //  x=shuffle(v1,x2,newMask)
  // newMask[i] = (mask[i] < x1.size()) ? mask1[mask[i]] : mask[i]
  // 3.
  // x2=shuffle(v2,undef,mask2)
  //  x=shuffle(x1,x2,mask)
  // where v2.size() == mask2.size()
  //        ==>
  //  x=shuffle(x1,v2,newMask)
  // newMask[i] = (mask[i] < x1.size())
  //              ? mask[i] : mask2[mask[i]-x1.size()]+x1.size()
  // 4.
  // x1=shuffle(v1,undef,mask1)
  // x2=shuffle(v2,undef,mask2)
  //  x=shuffle(x1,x2,mask)
  // where v1.size() == v2.size()
  //        ==>
  //  x=shuffle(v1,v2,newMask)
  // newMask[i] = (mask[i] < x1.size())
  //              ? mask1[mask[i]] : mask2[mask[i]-x1.size()]+v1.size()
  //
  // This is deliberately conservative: a shuffle mask that is not in the
  // input program is never produced, because the code generator may not be
  // smart enough to turn a merged shuffle into two specific shuffles and may
  // produce worse code. Two shuffles are therefore merged only when the
  // result is either a splat or one of the input shuffle masks. In that case
  // the merge just removes one instruction, which is known to be safe. This
  // is good for things like turning: (splat(splat)) -> splat, or
  // merge(V[0..n], V[n+1..2n]) -> V[0..2n]
  ShuffleVectorInst *LHSShuffle = dyn_cast<ShuffleVectorInst>(LHS);
  ShuffleVectorInst *RHSShuffle = dyn_cast<ShuffleVectorInst>(RHS);
  if (LHSShuffle)
    if (!match(LHSShuffle->getOperand(1), m_Undef()) && !match(RHS, m_Undef()))
      LHSShuffle = nullptr;
  if (RHSShuffle)
    if (!match(RHSShuffle->getOperand(1), m_Undef()))
      RHSShuffle = nullptr;
  if (!LHSShuffle && !RHSShuffle)
    return MadeChange ? &SVI : nullptr;

  Value *LHSOp0 = nullptr;
  Value *LHSOp1 = nullptr;
  Value *RHSOp0 = nullptr;
  unsigned LHSOp0Width = 0;
  unsigned RHSOp0Width = 0;
  if (LHSShuffle) {
    LHSOp0 = LHSShuffle->getOperand(0);
    LHSOp1 = LHSShuffle->getOperand(1);
    LHSOp0Width = cast<FixedVectorType>(LHSOp0->getType())->getNumElements();
  }
  if (RHSShuffle) {
    RHSOp0 = RHSShuffle->getOperand(0);
    RHSOp0Width = cast<FixedVectorType>(RHSOp0->getType())->getNumElements();
  }
  Value *NewLHS = LHS;
  Value *NewRHS = RHS;
  if (LHSShuffle) {
    if (match(RHS, m_Undef())) {
      // case 1
      NewLHS = LHSOp0;
      NewRHS = LHSOp1;
    } else if (LHSOp0Width == LHSWidth) {
      // case 2 or 4
      NewLHS = LHSOp0;
    }
  }
  // case 3 or 4
  if (RHSShuffle && RHSOp0Width == LHSWidth) {
    NewRHS = RHSOp0;
  }
  // case 4
  if (LHSOp0 == RHSOp0) {
    NewLHS = LHSOp0;
    NewRHS = nullptr;
  }

  if (NewLHS == LHS && NewRHS == RHS)
    return MadeChange ? &SVI : nullptr;

  ArrayRef<int> LHSMask;
  ArrayRef<int> RHSMask;
  if (NewLHS != LHS)
    LHSMask = LHSShuffle->getShuffleMask();
  if (RHSShuffle && NewRHS != RHS)
    RHSMask = RHSShuffle->getShuffleMask();

  unsigned NewLHSWidth = (NewLHS != LHS) ? LHSOp0Width : LHSWidth;
  SmallVector<int, 16> NewMask;
  bool IsSplat = true;
  int SplatElt = -1;
  // Build the mask for the replacement shuffle so that it is equivalent to
  // the original one.
  for (unsigned i = 0; i < VWidth; ++i) {
    int EltMask;
    if (Mask[i] < 0) {
      // This element is an undef value.
      EltMask = -1;
    } else if (Mask[i] < (int)LHSWidth) {
      // This element comes from the left hand side vector operand.
      //
      // If LHS is going to be replaced (case 1, 2, or 4), translate the
      // element through the LHS shuffle's mask.
      if (NewLHS != LHS) {
        EltMask = LHSMask[Mask[i]];
        // If the value selected is an undef value, explicitly specify it
        // with a -1 mask value.
        if (EltMask >= (int)LHSOp0Width && isa<UndefValue>(LHSOp1))
          EltMask = -1;
      } else
        EltMask = Mask[i];
    } else {
      // This element comes from the right hand side vector operand.
      //
      // If the value selected is an undef value, explicitly specify it
      // with a -1 mask value. (case 1)
      if (match(RHS, m_Undef()))
        EltMask = -1;
      // If RHS is going to be replaced (case 3 or 4), translate the element
      // through the RHS shuffle's mask.
      else if (NewRHS != RHS) {
        EltMask = RHSMask[Mask[i]-LHSWidth];
        // If the value selected is an undef value, explicitly specify it
        // with a -1 mask value.
        if (EltMask >= (int)RHSOp0Width) {
          assert(match(RHSShuffle->getOperand(1), m_Undef()) &&
                 "should have been check above");
          EltMask = -1;
        }
      } else
        EltMask = Mask[i]-LHSWidth;

      // If LHS's width is changed, shift the mask value accordingly.
      // If NewRHS == nullptr, i.e. LHSOp0 == RHSOp0, references to RHSOp0 are
      // remapped to LHSOp0, so the mask does not need to be shifted.
      // If NewRHS == NewLHS, references to NewRHS are remapped to NewLHS so
      // that splats that may occur due to obfuscation across the two vectors
      // are properly identified.
      if (EltMask >= 0 && NewRHS != nullptr && NewLHS != NewRHS)
        EltMask += NewLHSWidth;
    }

    // Check if this could still be a splat.
    if (EltMask >= 0) {
      if (SplatElt >= 0 && SplatElt != EltMask)
        IsSplat = false;
      SplatElt = EltMask;
    }

    NewMask.push_back(EltMask);
  }

  // Do the replacement only if the merged mask is a splat or is equal to one
  // of the masks that already existed.
  if (IsSplat || NewMask == LHSMask || NewMask == RHSMask || NewMask == Mask) {
    if (!NewRHS)
      NewRHS = UndefValue::get(NewLHS->getType());
    return new ShuffleVectorInst(NewLHS, NewRHS, NewMask);
  }

  return MadeChange ? &SVI : nullptr;
}

diff --git a/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp b/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp new file mode 100644 index 000000000000..0816a4a575d9 --- /dev/null +++ b/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp @@ -0,0 +1,4676 @@ +//===- InstructionCombining.cpp - Combine multiple instructions -----------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// InstructionCombining - Combine instructions to form fewer, simple +// instructions. This pass does not modify the CFG. This pass is where +// algebraic simplification happens. +// +// This pass combines things like: +// %Y = add i32 %X, 1 +// %Z = add i32 %Y, 1 +// into: +// %Z = add i32 %X, 2 +// +// This is a simple worklist driven algorithm. +// +// This pass guarantees that the following canonicalizations are performed on +// the program: +// 1. If a binary operator has a constant operand, it is moved to the RHS +// 2. Bitwise operators with constant operands are always grouped so that +// shifts are performed first, then or's, then and's, then xor's. +// 3. Compare instructions are converted from <,>,<=,>= to ==,!= if possible +// 4. All cmp instructions on boolean values are replaced with logical ops +// 5. add X, X is represented as (X*2) => (X << 1) +// 6. Multiplies with a power-of-two constant argument are transformed into +// shifts. +// ... etc.
+// +//===----------------------------------------------------------------------===// + +#include "InstCombineInternal.h" +#include "llvm-c/Initialization.h" +#include "llvm-c/Transforms/InstCombine.h" +#include "llvm/ADT/APInt.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/None.h" +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/Statistic.h" +#include "llvm/Analysis/AliasAnalysis.h" +#include "llvm/Analysis/AssumptionCache.h" +#include "llvm/Analysis/BasicAliasAnalysis.h" +#include "llvm/Analysis/BlockFrequencyInfo.h" +#include "llvm/Analysis/CFG.h" +#include "llvm/Analysis/ConstantFolding.h" +#include "llvm/Analysis/EHPersonalities.h" +#include "llvm/Analysis/GlobalsModRef.h" +#include "llvm/Analysis/InstructionSimplify.h" +#include "llvm/Analysis/LazyBlockFrequencyInfo.h" +#include "llvm/Analysis/LoopInfo.h" +#include "llvm/Analysis/MemoryBuiltins.h" +#include "llvm/Analysis/OptimizationRemarkEmitter.h" +#include "llvm/Analysis/ProfileSummaryInfo.h" +#include "llvm/Analysis/TargetFolder.h" +#include "llvm/Analysis/TargetLibraryInfo.h" +#include "llvm/Analysis/TargetTransformInfo.h" +#include "llvm/Analysis/Utils/Local.h" +#include "llvm/Analysis/ValueTracking.h" +#include "llvm/Analysis/VectorUtils.h" +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/CFG.h" +#include "llvm/IR/Constant.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/DIBuilder.h" +#include "llvm/IR/DataLayout.h" +#include "llvm/IR/DebugInfo.h" +#include "llvm/IR/DerivedTypes.h" +#include "llvm/IR/Dominators.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/GetElementPtrTypeIterator.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/InstrTypes.h" +#include "llvm/IR/Instruction.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/Intrinsics.h" +#include "llvm/IR/LegacyPassManager.h" +#include "llvm/IR/Metadata.h" +#include "llvm/IR/Operator.h" +#include 
"llvm/IR/PassManager.h" +#include "llvm/IR/PatternMatch.h" +#include "llvm/IR/Type.h" +#include "llvm/IR/Use.h" +#include "llvm/IR/User.h" +#include "llvm/IR/Value.h" +#include "llvm/IR/ValueHandle.h" +#include "llvm/InitializePasses.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/Compiler.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/DebugCounter.h" +#include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/KnownBits.h" +#include "llvm/Support/raw_ostream.h" +#include "llvm/Transforms/InstCombine/InstCombine.h" +#include "llvm/Transforms/Utils/Local.h" +#include <algorithm> +#include <cassert> +#include <cstdint> +#include <memory> +#include <string> +#include <utility> + +#define DEBUG_TYPE "instcombine" +#include "llvm/Transforms/Utils/InstructionWorklist.h" + +using namespace llvm; +using namespace llvm::PatternMatch; + +STATISTIC(NumWorklistIterations, + "Number of instruction combining iterations performed"); + +STATISTIC(NumCombined , "Number of insts combined"); +STATISTIC(NumConstProp, "Number of constant folds"); +STATISTIC(NumDeadInst , "Number of dead inst eliminated"); +STATISTIC(NumSunkInst , "Number of instructions sunk"); +STATISTIC(NumExpand, "Number of expansions"); +STATISTIC(NumFactor , "Number of factorizations"); +STATISTIC(NumReassoc , "Number of reassociations"); +DEBUG_COUNTER(VisitCounter, "instcombine-visit", + "Controls which instructions are visited"); + +// FIXME: these limits eventually should be as low as 2. 
+// Default cap on the number of combining iterations; used below as the initial
+// value of the -instcombine-max-iterations option.
+static constexpr unsigned InstCombineDefaultMaxIterations = 1000;
+// Default iteration count that is treated as an infinite combine loop; used
+// below as the initial value of -instcombine-infinite-loop-threshold. When
+// NDEBUG is not defined the lower value (100) is used, otherwise 1000.
+#ifndef NDEBUG
+static constexpr unsigned InstCombineDefaultInfiniteLoopThreshold = 100;
+#else
+static constexpr unsigned InstCombineDefaultInfiniteLoopThreshold = 1000;
+#endif
+
+// -instcombine-code-sinking: boolean switch for code sinking, on by default.
+static cl::opt<bool>
+EnableCodeSinking("instcombine-code-sinking", cl::desc("Enable code sinking"),
+                                              cl::init(true));
+
+// -instcombine-max-sink-users: limit on undroppable users considered when
+// sinking an instruction (default 32).
+static cl::opt<unsigned> MaxSinkNumUsers(
+    "instcombine-max-sink-users", cl::init(32),
+    cl::desc("Maximum number of undroppable users for instruction sinking"));
+
+// -instcombine-max-iterations: user-settable iteration cap; defaults to
+// InstCombineDefaultMaxIterations.
+static cl::opt<unsigned> LimitMaxIterations(
+    "instcombine-max-iterations",
+    cl::desc("Limit the maximum number of instruction combining iterations"),
+    cl::init(InstCombineDefaultMaxIterations));
+
+// -instcombine-infinite-loop-threshold: hidden option; defaults to
+// InstCombineDefaultInfiniteLoopThreshold.
+static cl::opt<unsigned> InfiniteLoopDetectionThreshold(
+    "instcombine-infinite-loop-threshold",
+    cl::desc("Number of instruction combining iterations considered an "
+             "infinite loop"),
+    cl::init(InstCombineDefaultInfiniteLoopThreshold), cl::Hidden);
+
+// -instcombine-maxarray-size: largest array size a combine will consider
+// (default 1024).
+static cl::opt<unsigned>
+MaxArraySize("instcombine-maxarray-size", cl::init(1024),
+             cl::desc("Maximum array size considered when doing a combine"));
+
+// FIXME: Remove this flag when it is no longer necessary to convert
+// llvm.dbg.declare to avoid inaccurate debug info. Setting this to false
+// increases variable availability at the cost of accuracy. Variables that
+// cannot be promoted by mem2reg or SROA will be described as living in memory
+// for their entire lifetime. However, passes like DSE and instcombine can
+// delete stores to the alloca, leading to misleading and inaccurate debug
+// information. This flag can be removed when those passes are fixed.
+static cl::opt<unsigned> ShouldLowerDbgDeclare("instcombine-lower-dbg-declare",
+                                               cl::Hidden, cl::init(true));
+
+/// Offer \p II to the target for combining. The request is forwarded to TTI
+/// only when the callee is a target-specific intrinsic; in every other case
+/// None is returned.
+Optional<Instruction *>
+InstCombiner::targetInstCombineIntrinsic(IntrinsicInst &II) {
+  if (!II.getCalledFunction()->isTargetIntrinsic())
+    return None;
+
+  return TTI.instCombineIntrinsic(*this, II);
+}
+
+/// Demanded-bits counterpart of targetInstCombineIntrinsic(): target
+/// intrinsics are handed to TTI together with the demanded mask and the
+/// known-bits state, anything else yields None.
+Optional<Value *> InstCombiner::targetSimplifyDemandedUseBitsIntrinsic(
+    IntrinsicInst &II, APInt DemandedMask, KnownBits &Known,
+    bool &KnownBitsComputed) {
+  if (!II.getCalledFunction()->isTargetIntrinsic())
+    return None;
+
+  return TTI.simplifyDemandedUseBitsIntrinsic(*this, II, DemandedMask, Known,
+                                              KnownBitsComputed);
+}
+
+/// Demanded-vector-elements counterpart of targetInstCombineIntrinsic():
+/// target intrinsics are handed to TTI with all arguments passed through
+/// unchanged, anything else yields None.
+Optional<Value *> InstCombiner::targetSimplifyDemandedVectorEltsIntrinsic(
+    IntrinsicInst &II, APInt DemandedElts, APInt &UndefElts, APInt &UndefElts2,
+    APInt &UndefElts3,
+    std::function<void(Instruction *, unsigned, APInt, APInt &)>
+        SimplifyAndSetOp) {
+  if (!II.getCalledFunction()->isTargetIntrinsic())
+    return None;
+
+  return TTI.simplifyDemandedVectorEltsIntrinsic(
+      *this, II, DemandedElts, UndefElts, UndefElts2, UndefElts3,
+      SimplifyAndSetOp);
+}
+
+/// Thin wrapper: delegates to llvm::EmitGEPOffset using this combiner's
+/// Builder and DataLayout.
+Value *InstCombinerImpl::EmitGEPOffset(User *GEP) {
+  return llvm::EmitGEPOffset(&Builder, DL, GEP);
+}
+
+/// An integer width is "desirable" when it is either legal for the target or
+/// one of the common frontend widths. This is used to avoid creating
+/// instructions with types that may not be supported well by the backend.
+/// NOTE: i8, i16 and i32 are accepted unconditionally because they are common
+///       types in frontend languages.
+bool InstCombinerImpl::isDesirableIntType(unsigned BitWidth) const {
+  const bool IsCommonWidth = BitWidth == 8 || BitWidth == 16 || BitWidth == 32;
+  return IsCommonWidth || DL.isLegalInteger(BitWidth);
+}
+
+/// Decide whether an integer computation should be converted from
+/// \p FromWidth bits to \p ToWidth bits.
+/// Conversions from a legal to an illegal width, and from a smaller to a
+/// larger illegal width, are rejected. A width of 1 always counts as legal
+/// because i1 is a fundamental type in IR with many specialized
+/// optimizations. Shrinking to a common/desirable width is always accepted,
+/// in order to open up more combining opportunities.
+bool InstCombinerImpl::shouldChangeType(unsigned FromWidth,
+                                        unsigned ToWidth) const {
+  const bool SrcIsLegal = FromWidth == 1 || DL.isLegalInteger(FromWidth);
+  const bool DstIsLegal = ToWidth == 1 || DL.isLegalInteger(ToWidth);
+
+  // Desirable widths are fine even when they are not legal, but only when the
+  // type shrinks; allowing growth here could make combines loop forever.
+  const bool Shrinks = ToWidth < FromWidth;
+  if (Shrinks && isDesirableIntType(ToWidth))
+    return true;
+
+  // Any conversion to a legal width is acceptable.
+  if (DstIsLegal)
+    return true;
+
+  // The destination is illegal from here on. Never go from legal to illegal,
+  // and between two illegal widths never grow: i160 -> i64 is allowed,
+  // i64 -> i160 is not.
+  return !SrcIsLegal && ToWidth <= FromWidth;
+}
+
+/// Type-based overload of shouldChangeType(): both \p From and \p To must be
+/// integer types, otherwise the answer is false; the decision is then made on
+/// their bit widths by the width-based overload.
+bool InstCombinerImpl::shouldChangeType(Type *From, Type *To) const {
+  // TODO: This could be extended to allow vectors. Datalayout changes might be
+  // needed to properly support that.
+  const bool BothAreIntegers = From->isIntegerTy() && To->isIntegerTy();
+  if (!BothAreIntegers)
+    return false;
+
+  unsigned SrcBits = From->getPrimitiveSizeInBits();
+  unsigned DstBits = To->getPrimitiveSizeInBits();
+  return shouldChangeType(SrcBits, DstBits);
+}
+
+// Return true if the No Signed Wrap flag should be maintained for I.
+// The flag can be kept if I already carries it and the operation
+// "B (I.getOpcode) C", with both B and C matching m_APInt constants, produces
+// a result that does not overflow in the signed sense. Only Add and Sub are
+// handled; for every other opcode the function conservatively returns false.
+static bool maintainNoSignedWrap(BinaryOperator &I, Value *B, Value *C) {
+  const auto *Overflowing = dyn_cast<OverflowingBinaryOperator>(&I);
+  const bool CarriesNSW = Overflowing && Overflowing->hasNoSignedWrap();
+  if (!CarriesNSW)
+    return false;
+
+  // Only Add and Sub are reasoned about.
+  const Instruction::BinaryOps Opcode = I.getOpcode();
+  switch (Opcode) {
+  case Instruction::Add:
+  case Instruction::Sub:
+    break;
+  default:
+    return false;
+  }
+
+  const APInt *LHSConst, *RHSConst;
+  if (!(match(B, m_APInt(LHSConst)) && match(C, m_APInt(RHSConst))))
+    return false;
+
+  // Only the overflow indicator matters; the arithmetic result is discarded.
+  bool Overflowed = false;
+  if (Opcode == Instruction::Add)
+    (void)LHSConst->sadd_ov(*RHSConst, Overflowed);
+  else
+    (void)LHSConst->ssub_ov(*RHSConst, Overflowed);
+
+  return !Overflowed;
+}
+
+// True if I is an overflowing binary operator that carries the nuw flag.
+static bool hasNoUnsignedWrap(BinaryOperator &I) {
+  if (auto *OBO = dyn_cast<OverflowingBinaryOperator>(&I))
+    return OBO->hasNoUnsignedWrap();
+  return false;
+}
+
+// True if I is an overflowing binary operator that carries the nsw flag.
+static bool hasNoSignedWrap(BinaryOperator &I) {
+  if (auto *OBO = dyn_cast<OverflowingBinaryOperator>(&I))
+    return OBO->hasNoSignedWrap();
+  return false;
+}
+
+/// Conservatively clears subclassOptionalData after a reassociation or
+/// commutation. We preserve fast-math flags when applicable as they can be
+/// preserved.
+static void ClearSubclassDataAfterReassociation(BinaryOperator &I) {
+  FPMathOperator *FPMO = dyn_cast<FPMathOperator>(&I);
+  // Not a floating-point operation: simply drop every optional flag.
+  if (!FPMO) {
+    I.clearSubclassOptionalData();
+    return;
+  }
+
+  // Floating-point operation: save the fast-math flags, wipe the optional
+  // data, then put the saved flags back.
+  FastMathFlags FMF = I.getFastMathFlags();
+  I.clearSubclassOptionalData();
+  I.setFastMathFlags(FMF);
+}
+
+/// Combine constant operands of associative operations either before or after a
+/// cast to eliminate one of the associative operations:
+/// (op (cast (op X, C2)), C1) --> (cast (op X, op (C1, C2)))
+/// (op (cast (op X, C2)), C1) --> (op (cast X), op (C1, C2))
+///
+/// As implemented here only the second form is produced, and only when the
+/// cast is a zext and \p BinOp1 is a bitwise logic op. \p BinOp1 and the cast
+/// are modified in place through \p IC.
+/// \returns true if the operands were rewritten, false otherwise.
+static bool simplifyAssocCastAssoc(BinaryOperator *BinOp1,
+                                   InstCombinerImpl &IC) {
+  // Operand 0 must be a single-use cast.
+  auto *Cast = dyn_cast<CastInst>(BinOp1->getOperand(0));
+  if (!Cast || !Cast->hasOneUse())
+    return false;
+
+  // TODO: Enhance logic for other casts and remove this check.
+  auto CastOpcode = Cast->getOpcode();
+  if (CastOpcode != Instruction::ZExt)
+    return false;
+
+  // TODO: Enhance logic for other BinOps and remove this check.
+  if (!BinOp1->isBitwiseLogicOp())
+    return false;
+
+  // The cast's source must be a single-use binop with the same opcode.
+  auto AssocOpcode = BinOp1->getOpcode();
+  auto *BinOp2 = dyn_cast<BinaryOperator>(Cast->getOperand(0));
+  if (!BinOp2 || !BinOp2->hasOneUse() || BinOp2->getOpcode() != AssocOpcode)
+    return false;
+
+  // Both binops need a constant as their second operand.
+  Constant *C1, *C2;
+  if (!match(BinOp1->getOperand(1), m_Constant(C1)) ||
+      !match(BinOp2->getOperand(1), m_Constant(C2)))
+    return false;
+
+  // TODO: This assumes a zext cast.
+  // Eg, if it was a trunc, we'd cast C1 to the source type because casting C2
+  // to the destination type might lose bits.
+
+  // Fold the constants together in the destination type:
+  // (op (cast (op X, C2)), C1) --> (op (cast X), FoldedC)
+  Type *DestTy = C1->getType();
+  Constant *CastC2 = ConstantExpr::getCast(CastOpcode, C2, DestTy);
+  Constant *FoldedC = ConstantExpr::get(AssocOpcode, C1, CastC2);
+  // The cast now reads X directly; BinOp1 takes the folded constant.
+  IC.replaceOperand(*Cast, 0, BinOp2->getOperand(0));
+  IC.replaceOperand(*BinOp1, 1, FoldedC);
+  return true;
+}
+
+// Simplifies IntToPtr/PtrToInt RoundTrip Cast To BitCast.
+// inttoptr ( ptrtoint (x) ) --> x
+//
+// The fold only applies when, for both casts, the pointer size equals the
+// integer size, and when the source and result pointers are in the same
+// address space. On success a new bit-or-pointer cast of x to the inttoptr's
+// destination type is created and returned; otherwise nullptr is returned.
+// NOTE(review): the last CreateBitOrPointerCast argument is presumably the
+// insert-before position (the ptrtoint) - confirm against the CastInst API.
+Value *InstCombinerImpl::simplifyIntToPtrRoundTripCast(Value *Val) {
+  auto *IntToPtr = dyn_cast<IntToPtrInst>(Val);
+  if (IntToPtr && DL.getPointerTypeSizeInBits(IntToPtr->getDestTy()) ==
+                      DL.getTypeSizeInBits(IntToPtr->getSrcTy())) {
+    auto *PtrToInt = dyn_cast<PtrToIntInst>(IntToPtr->getOperand(0));
+    Type *CastTy = IntToPtr->getDestTy();
+    if (PtrToInt &&
+        CastTy->getPointerAddressSpace() ==
+            PtrToInt->getSrcTy()->getPointerAddressSpace() &&
+        DL.getPointerTypeSizeInBits(PtrToInt->getSrcTy()) ==
+            DL.getTypeSizeInBits(PtrToInt->getDestTy())) {
+      return CastInst::CreateBitOrPointerCast(PtrToInt->getOperand(0), CastTy,
+                                              "", PtrToInt);
+    }
+  }
+  return nullptr;
+}
+
+/// This performs a few simplifications for operators that are associative or
+/// commutative:
+///
+///  Commutative operators:
+///
+///  1. Order operands such that they are listed from right (least complex) to
+///     left (most complex). This puts constants before unary operators before
+///     binary operators.
+///
+///  Associative operators:
+///
+///  2. Transform: "(A op B) op C" ==> "A op (B op C)" if "B op C" simplifies.
+///  3. Transform: "A op (B op C)" ==> "(A op B) op C" if "A op B" simplifies.
+///
+///  Associative and commutative operators:
+///
+///  4. Transform: "(A op B) op C" ==> "(C op A) op B" if "C op A" simplifies.
+///  5. Transform: "A op (B op C)" ==> "B op (C op A)" if "C op A" simplifies.
+///  6. Transform: "(A op C1) op (B op C2)" ==> "(A op B) op (C1 op C2)"
+///     if C1 and C2 are constants.
+///
+/// The transforms are retried in a loop: each successful one restarts the
+/// loop via 'continue', and the function returns once a full pass applies
+/// none of transforms 2-6.
+/// \returns the accumulated Changed flag for \p I.
+bool InstCombinerImpl::SimplifyAssociativeOrCommutative(BinaryOperator &I) {
+  Instruction::BinaryOps Opcode = I.getOpcode();
+  bool Changed = false;
+
+  do {
+    // Order operands such that they are listed from right (least complex) to
+    // left (most complex). This puts constants before unary operators before
+    // binary operators.
+    // NOTE(review): Changed is assigned (not OR-ed) from the negated
+    // swapOperands() result; presumably swapOperands() returns false on
+    // success, so this always yields true here - confirm.
+    if (I.isCommutative() && getComplexity(I.getOperand(0)) <
+        getComplexity(I.getOperand(1)))
+      Changed = !I.swapOperands();
+
+    // Re-read the operands on every iteration: earlier transforms may have
+    // replaced them.
+    BinaryOperator *Op0 = dyn_cast<BinaryOperator>(I.getOperand(0));
+    BinaryOperator *Op1 = dyn_cast<BinaryOperator>(I.getOperand(1));
+
+    if (I.isAssociative()) {
+      // Transform: "(A op B) op C" ==> "A op (B op C)" if "B op C" simplifies.
+      if (Op0 && Op0->getOpcode() == Opcode) {
+        Value *A = Op0->getOperand(0);
+        Value *B = Op0->getOperand(1);
+        Value *C = I.getOperand(1);
+
+        // Does "B op C" simplify?
+        if (Value *V = simplifyBinOp(Opcode, B, C, SQ.getWithInstruction(&I))) {
+          // It simplifies to V. Form "A op V".
+          replaceOperand(I, 0, A);
+          replaceOperand(I, 1, V);
+          // The wrap flags are read before they are cleared below. The nsw
+          // check uses the saved B and C, not I's replaced operands.
+          bool IsNUW = hasNoUnsignedWrap(I) && hasNoUnsignedWrap(*Op0);
+          bool IsNSW = maintainNoSignedWrap(I, B, C) && hasNoSignedWrap(*Op0);
+
+          // Conservatively clear all optional flags since they may not be
+          // preserved by the reassociation. Reset nsw/nuw based on the above
+          // analysis.
+          ClearSubclassDataAfterReassociation(I);
+
+          // Note: this is only valid because SimplifyBinOp doesn't look at
+          // the operands to Op0.
+          if (IsNUW)
+            I.setHasNoUnsignedWrap(true);
+
+          if (IsNSW)
+            I.setHasNoSignedWrap(true);
+
+          Changed = true;
+          ++NumReassoc;
+          continue;
+        }
+      }
+
+      // Transform: "A op (B op C)" ==> "(A op B) op C" if "A op B" simplifies.
+      if (Op1 && Op1->getOpcode() == Opcode) {
+        Value *A = I.getOperand(0);
+        Value *B = Op1->getOperand(0);
+        Value *C = Op1->getOperand(1);
+
+        // Does "A op B" simplify?
+        if (Value *V = simplifyBinOp(Opcode, A, B, SQ.getWithInstruction(&I))) {
+          // It simplifies to V. Form "V op C".
+          replaceOperand(I, 0, V);
+          replaceOperand(I, 1, C);
+          // Conservatively clear the optional flags, since they may not be
+          // preserved by the reassociation.
+          ClearSubclassDataAfterReassociation(I);
+          Changed = true;
+          ++NumReassoc;
+          continue;
+        }
+      }
+    }
+
+    if (I.isAssociative() && I.isCommutative()) {
+      // Try folding constants across a cast first.
+      if (simplifyAssocCastAssoc(&I, *this)) {
+        Changed = true;
+        ++NumReassoc;
+        continue;
+      }
+
+      // Transform: "(A op B) op C" ==> "(C op A) op B" if "C op A" simplifies.
+      if (Op0 && Op0->getOpcode() == Opcode) {
+        Value *A = Op0->getOperand(0);
+        Value *B = Op0->getOperand(1);
+        Value *C = I.getOperand(1);
+
+        // Does "C op A" simplify?
+        if (Value *V = simplifyBinOp(Opcode, C, A, SQ.getWithInstruction(&I))) {
+          // It simplifies to V. Form "V op B".
+          replaceOperand(I, 0, V);
+          replaceOperand(I, 1, B);
+          // Conservatively clear the optional flags, since they may not be
+          // preserved by the reassociation.
+          ClearSubclassDataAfterReassociation(I);
+          Changed = true;
+          ++NumReassoc;
+          continue;
+        }
+      }
+
+      // Transform: "A op (B op C)" ==> "B op (C op A)" if "C op A" simplifies.
+      if (Op1 && Op1->getOpcode() == Opcode) {
+        Value *A = I.getOperand(0);
+        Value *B = Op1->getOperand(0);
+        Value *C = Op1->getOperand(1);
+
+        // Does "C op A" simplify?
+        if (Value *V = simplifyBinOp(Opcode, C, A, SQ.getWithInstruction(&I))) {
+          // It simplifies to V. Form "B op V".
+          replaceOperand(I, 0, B);
+          replaceOperand(I, 1, V);
+          // Conservatively clear the optional flags, since they may not be
+          // preserved by the reassociation.
+          ClearSubclassDataAfterReassociation(I);
+          Changed = true;
+          ++NumReassoc;
+          continue;
+        }
+      }
+
+      // Transform: "(A op C1) op (B op C2)" ==> "(A op B) op (C1 op C2)"
+      // if C1 and C2 are constants.
+      Value *A, *B;
+      Constant *C1, *C2;
+      if (Op0 && Op1 &&
+          Op0->getOpcode() == Opcode && Op1->getOpcode() == Opcode &&
+          match(Op0, m_OneUse(m_BinOp(m_Value(A), m_Constant(C1)))) &&
+          match(Op1, m_OneUse(m_BinOp(m_Value(B), m_Constant(C2))))) {
+        // nuw survives only if all three operations had it; it is put on the
+        // new inner operation only for Add.
+        bool IsNUW = hasNoUnsignedWrap(I) &&
+                     hasNoUnsignedWrap(*Op0) &&
+                     hasNoUnsignedWrap(*Op1);
+        BinaryOperator *NewBO = (IsNUW && Opcode == Instruction::Add) ?
+          BinaryOperator::CreateNUW(Opcode, A, B) :
+          BinaryOperator::Create(Opcode, A, B);
+
+        // For FP operations the new instruction gets the intersection of the
+        // fast-math flags of I, Op0 and Op1.
+        if (isa<FPMathOperator>(NewBO)) {
+          FastMathFlags Flags = I.getFastMathFlags();
+          Flags &= Op0->getFastMathFlags();
+          Flags &= Op1->getFastMathFlags();
+          NewBO->setFastMathFlags(Flags);
+        }
+        InsertNewInstWith(NewBO, I);
+        NewBO->takeName(Op1);
+        replaceOperand(I, 0, NewBO);
+        replaceOperand(I, 1, ConstantExpr::get(Opcode, C1, C2));
+        // Conservatively clear the optional flags, since they may not be
+        // preserved by the reassociation.
+        ClearSubclassDataAfterReassociation(I);
+        if (IsNUW)
+          I.setHasNoUnsignedWrap(true);
+
+        // Note: unlike the transforms above, this one does not increment
+        // NumReassoc.
+        Changed = true;
+        continue;
+      }
+    }
+
+    // No further simplifications.
+    return Changed;
+  } while (true);
+}
+
+/// Return whether "X LOp (Y ROp Z)" is always equal to
+/// "(X LOp Y) ROp (X LOp Z)".
+/// Only And (over Or/Xor), Or (over And) and Mul (over Add/Sub) are
+/// recognized; every other LOp yields false.
+static bool leftDistributesOverRight(Instruction::BinaryOps LOp,
+                                     Instruction::BinaryOps ROp) {
+  // X & (Y | Z) <--> (X & Y) | (X & Z)
+  // X & (Y ^ Z) <--> (X & Y) ^ (X & Z)
+  if (LOp == Instruction::And)
+    return ROp == Instruction::Or || ROp == Instruction::Xor;
+
+  // X | (Y & Z) <--> (X | Y) & (X | Z)
+  if (LOp == Instruction::Or)
+    return ROp == Instruction::And;
+
+  // X * (Y + Z) <--> (X * Y) + (X * Z)
+  // X * (Y - Z) <--> (X * Y) - (X * Z)
+  if (LOp == Instruction::Mul)
+    return ROp == Instruction::Add || ROp == Instruction::Sub;
+
+  return false;
+}
+
+/// Return whether "(X LOp Y) ROp Z" is always equal to
+/// "(X ROp Z) LOp (Y ROp Z)".
+static bool rightDistributesOverLeft(Instruction::BinaryOps LOp,
+                                     Instruction::BinaryOps ROp) {
+  // With a commutative ROp the question reduces to left-distribution with the
+  // two opcodes exchanged.
+  if (Instruction::isCommutative(ROp))
+    return leftDistributesOverRight(ROp, LOp);
+
+  // (X {&|^} Y) >> Z <--> (X >> Z) {&|^} (Y >> Z) for all shifts.
+  //
+  // TODO: It would be nice to handle division, aka "(X + Y)/Z = X/Z + Y/Z",
+  // but this requires knowing that the addition does not overflow and other
+  // such subtleties.
+  if (!Instruction::isBitwiseLogicOp(LOp))
+    return false;
+  return Instruction::isShift(ROp);
+}
+
+/// Produce the identity constant of \p Opcode in the type of \p V, so that a
+/// bare value can take part in factorization, e.g.
+/// (X * 2) + X ==> (X * 2) + (X * 1) ==> X * (2 + 1).
+/// Returns null when \p V is itself a constant.
+static Value *getIdentityValue(Instruction::BinaryOps Opcode, Value *V) {
+  return isa<Constant>(V)
+             ? nullptr
+             : ConstantExpr::getBinOpIdentity(Opcode, V->getType());
+}
+
+/// Pick the operands and opcode under which \p Op should be considered for
+/// factorization by distributive laws. Normally \p LHS and \p RHS are simply
+/// the operands of \p Op and its own opcode is returned. The special case is
+/// a 'shl' by a constant underneath an Add or Sub (e.g. 'add(shl(X, 5), ...)'):
+/// it is presented as the more general 'mul X, 32', i.e. \p RHS becomes
+/// '1 << C' and Instruction::Mul is returned, to allow more factorization
+/// opportunities.
+static Instruction::BinaryOps
+getBinOpsForFactorization(Instruction::BinaryOps TopOpcode, BinaryOperator *Op,
+                          Value *&LHS, Value *&RHS) {
+  assert(Op && "Expected a binary operator");
+  LHS = Op->getOperand(0);
+  RHS = Op->getOperand(1);
+
+  const bool TopIsAddOrSub =
+      TopOpcode == Instruction::Add || TopOpcode == Instruction::Sub;
+  if (!TopIsAddOrSub)
+    return Op->getOpcode();
+
+  // TODO: We can add other conversions e.g. shr => div etc.
+  Constant *ShiftAmt;
+  if (!match(Op, m_Shl(m_Value(), m_Constant(ShiftAmt))))
+    return Op->getOpcode();
+
+  // X << C --> X * (1 << C)
+  Constant *One = ConstantInt::get(Op->getType(), 1);
+  RHS = ConstantExpr::getShl(One, ShiftAmt);
+  return Instruction::Mul;
+}
+
+/// This tries to simplify binary operations by factorizing out common terms
+/// (e. g. "(A*B)+(A*C)" -> "A*(B+C)").
+// \p I is "LHS op RHS" with top-level opcode op; \p InnerOpcode is op', and
+// A, B / C, D are the operands under which the two sides are viewed as
+// "(A op' B) op (C op' D)". Returns the newly built value, or null if no
+// factorization was performed.
+Value *InstCombinerImpl::tryFactorization(BinaryOperator &I,
+                                          Instruction::BinaryOps InnerOpcode,
+                                          Value *A, Value *B, Value *C,
+                                          Value *D) {
+  assert(A && B && C && D && "All values must be provided");
+
+  Value *V = nullptr;
+  Value *SimplifiedInst = nullptr;
+  Value *LHS = I.getOperand(0), *RHS = I.getOperand(1);
+  Instruction::BinaryOps TopLevelOpcode = I.getOpcode();
+
+  // Does "X op' Y" always equal "Y op' X"?
+  bool InnerCommutative = Instruction::isCommutative(InnerOpcode);
+
+  // Does "X op' (Y op Z)" always equal "(X op' Y) op (X op' Z)"?
+  if (leftDistributesOverRight(InnerOpcode, TopLevelOpcode))
+    // Does the instruction have the form "(A op' B) op (A op' D)" or, in the
+    // commutative case, "(A op' B) op (C op' A)"?
+    if (A == C || (InnerCommutative && A == D)) {
+      // Normalize so that the shared term is in A and C.
+      if (A != C)
+        std::swap(C, D);
+      // Consider forming "A op' (B op D)".
+      // If "B op D" simplifies then it can be formed with no cost.
+      V = simplifyBinOp(TopLevelOpcode, B, D, SQ.getWithInstruction(&I));
+      // If "B op D" doesn't simplify then only go on if both of the existing
+      // operations "A op' B" and "C op' D" will be zapped as no longer used.
+      if (!V && LHS->hasOneUse() && RHS->hasOneUse())
+        V = Builder.CreateBinOp(TopLevelOpcode, B, D, RHS->getName());
+      if (V) {
+        SimplifiedInst = Builder.CreateBinOp(InnerOpcode, A, V);
+      }
+    }
+
+  // Does "(X op Y) op' Z" always equal "(X op' Z) op (Y op' Z)"?
+  // NOTE(review): if the attempt above swapped C and D without producing a
+  // result, the checks below see the swapped values.
+  if (!SimplifiedInst && rightDistributesOverLeft(TopLevelOpcode, InnerOpcode))
+    // Does the instruction have the form "(A op' B) op (C op' B)" or, in the
+    // commutative case, "(A op' B) op (B op' D)"?
+    if (B == D || (InnerCommutative && B == C)) {
+      // Normalize so that the shared term is in B and D.
+      if (B != D)
+        std::swap(C, D);
+      // Consider forming "(A op C) op' B".
+      // If "A op C" simplifies then it can be formed with no cost.
+      V = simplifyBinOp(TopLevelOpcode, A, C, SQ.getWithInstruction(&I));
+
+      // If "A op C" doesn't simplify then only go on if both of the existing
+      // operations "A op' B" and "C op' D" will be zapped as no longer used.
+      if (!V && LHS->hasOneUse() && RHS->hasOneUse())
+        V = Builder.CreateBinOp(TopLevelOpcode, A, C, LHS->getName());
+      if (V) {
+        SimplifiedInst = Builder.CreateBinOp(InnerOpcode, V, B);
+      }
+    }
+
+  if (SimplifiedInst) {
+    ++NumFactor;
+    SimplifiedInst->takeName(&I);
+
+    // Check if we can add NSW/NUW flags to SimplifiedInst. If so, set them.
+    if (BinaryOperator *BO = dyn_cast<BinaryOperator>(SimplifiedInst)) {
+      if (isa<OverflowingBinaryOperator>(SimplifiedInst)) {
+        // Start from I's flags, then intersect with those of each operand
+        // that is itself an overflowing binary operator.
+        bool HasNSW = false;
+        bool HasNUW = false;
+        if (isa<OverflowingBinaryOperator>(&I)) {
+          HasNSW = I.hasNoSignedWrap();
+          HasNUW = I.hasNoUnsignedWrap();
+        }
+
+        if (auto *LOBO = dyn_cast<OverflowingBinaryOperator>(LHS)) {
+          HasNSW &= LOBO->hasNoSignedWrap();
+          HasNUW &= LOBO->hasNoUnsignedWrap();
+        }
+
+        if (auto *ROBO = dyn_cast<OverflowingBinaryOperator>(RHS)) {
+          HasNSW &= ROBO->hasNoSignedWrap();
+          HasNUW &= ROBO->hasNoUnsignedWrap();
+        }
+
+        // Flags are only propagated for the add-of-mul combination.
+        if (TopLevelOpcode == Instruction::Add &&
+            InnerOpcode == Instruction::Mul) {
+          // We can propagate 'nsw' if we know that
+          //  %Y = mul nsw i16 %X, C
+          //  %Z = add nsw i16 %Y, %X
+          // =>
+          //  %Z = mul nsw i16 %X, C+1
+          //
+          // iff C+1 isn't INT_MIN
+          const APInt *CInt;
+          if (match(V, m_APInt(CInt))) {
+            if (!CInt->isMinSignedValue())
+              BO->setHasNoSignedWrap(HasNSW);
+          }
+
+          // nuw can be propagated with any constant or nuw value.
+          BO->setHasNoUnsignedWrap(HasNUW);
+        }
+      }
+    }
+  }
+  return SimplifiedInst;
+}
+
+/// This tries to simplify binary operations which some other binary operation
+/// distributes over either by factorizing out common terms
+/// (eg "(A*B)+(A*C)" -> "A*(B+C)") or expanding out if this results in
+/// simplifications (eg: "A & (B | C) -> (A&B) | (A&C)" if this is a win).
+/// Returns the simplified value, or null if it didn't simplify.
+/// If neither factorization nor expansion applies, the result of
+/// SimplifySelectsFeedingBinaryOp() is returned.
+Value *InstCombinerImpl::SimplifyUsingDistributiveLaws(BinaryOperator &I) {
+  Value *LHS = I.getOperand(0), *RHS = I.getOperand(1);
+  BinaryOperator *Op0 = dyn_cast<BinaryOperator>(LHS);
+  BinaryOperator *Op1 = dyn_cast<BinaryOperator>(RHS);
+  Instruction::BinaryOps TopLevelOpcode = I.getOpcode();
+
+  {
+    // Factorization.
+    // A, B and LHSOpcode are only set when Op0 is non-null; C, D and
+    // RHSOpcode only when Op1 is non-null. Every use below is guarded by the
+    // same null check.
+    Value *A, *B, *C, *D;
+    Instruction::BinaryOps LHSOpcode, RHSOpcode;
+    if (Op0)
+      LHSOpcode = getBinOpsForFactorization(TopLevelOpcode, Op0, A, B);
+    if (Op1)
+      RHSOpcode = getBinOpsForFactorization(TopLevelOpcode, Op1, C, D);
+
+    // The instruction has the form "(A op' B) op (C op' D)". Try to factorize
+    // a common term.
+    if (Op0 && Op1 && LHSOpcode == RHSOpcode)
+      if (Value *V = tryFactorization(I, LHSOpcode, A, B, C, D))
+        return V;
+
+    // The instruction has the form "(A op' B) op (C)". Try to factorize common
+    // term.
+    if (Op0)
+      if (Value *Ident = getIdentityValue(LHSOpcode, RHS))
+        if (Value *V = tryFactorization(I, LHSOpcode, A, B, RHS, Ident))
+          return V;
+
+    // The instruction has the form "(B) op (C op' D)". Try to factorize common
+    // term.
+    if (Op1)
+      if (Value *Ident = getIdentityValue(RHSOpcode, LHS))
+        if (Value *V = tryFactorization(I, RHSOpcode, LHS, Ident, C, D))
+          return V;
+  }
+
+  // Expansion.
+  if (Op0 && rightDistributesOverLeft(Op0->getOpcode(), TopLevelOpcode)) {
+    // The instruction has the form "(A op' B) op C". See if expanding it out
+    // to "(A op C) op' (B op C)" results in simplifications.
+    Value *A = Op0->getOperand(0), *B = Op0->getOperand(1), *C = RHS;
+    Instruction::BinaryOps InnerOpcode = Op0->getOpcode(); // op'
+
+    // Disable the use of undef because it's not safe to distribute undef.
+    auto SQDistributive = SQ.getWithInstruction(&I).getWithoutUndef();
+    Value *L = simplifyBinOp(TopLevelOpcode, A, C, SQDistributive);
+    Value *R = simplifyBinOp(TopLevelOpcode, B, C, SQDistributive);
+
+    // Do "A op C" and "B op C" both simplify?
+    if (L && R) {
+      // They do! Return "L op' R".
+      ++NumExpand;
+      C = Builder.CreateBinOp(InnerOpcode, L, R);
+      C->takeName(&I);
+      return C;
+    }
+
+    // Does "A op C" simplify to the identity value for the inner opcode?
+    if (L && L == ConstantExpr::getBinOpIdentity(InnerOpcode, L->getType())) {
+      // They do! Return "B op C".
+      ++NumExpand;
+      C = Builder.CreateBinOp(TopLevelOpcode, B, C);
+      C->takeName(&I);
+      return C;
+    }
+
+    // Does "B op C" simplify to the identity value for the inner opcode?
+    if (R && R == ConstantExpr::getBinOpIdentity(InnerOpcode, R->getType())) {
+      // They do! Return "A op C".
+      ++NumExpand;
+      C = Builder.CreateBinOp(TopLevelOpcode, A, C);
+      C->takeName(&I);
+      return C;
+    }
+  }
+
+  if (Op1 && leftDistributesOverRight(TopLevelOpcode, Op1->getOpcode())) {
+    // The instruction has the form "A op (B op' C)". See if expanding it out
+    // to "(A op B) op' (A op C)" results in simplifications.
+    Value *A = LHS, *B = Op1->getOperand(0), *C = Op1->getOperand(1);
+    Instruction::BinaryOps InnerOpcode = Op1->getOpcode(); // op'
+
+    // Disable the use of undef because it's not safe to distribute undef.
+    auto SQDistributive = SQ.getWithInstruction(&I).getWithoutUndef();
+    Value *L = simplifyBinOp(TopLevelOpcode, A, B, SQDistributive);
+    Value *R = simplifyBinOp(TopLevelOpcode, A, C, SQDistributive);
+
+    // Do "A op B" and "A op C" both simplify?
+    if (L && R) {
+      // They do! Return "L op' R".
+      ++NumExpand;
+      A = Builder.CreateBinOp(InnerOpcode, L, R);
+      A->takeName(&I);
+      return A;
+    }
+
+    // Does "A op B" simplify to the identity value for the inner opcode?
+    if (L && L == ConstantExpr::getBinOpIdentity(InnerOpcode, L->getType())) {
+      // They do! Return "A op C".
+      ++NumExpand;
+      A = Builder.CreateBinOp(TopLevelOpcode, A, C);
+      A->takeName(&I);
+      return A;
+    }
+
+    // Does "A op C" simplify to the identity value for the inner opcode?
+    if (R && R == ConstantExpr::getBinOpIdentity(InnerOpcode, R->getType())) {
+      // They do! Return "A op B".
+      ++NumExpand;
+      A = Builder.CreateBinOp(TopLevelOpcode, A, B);
+      A->takeName(&I);
+      return A;
+    }
+  }
+
+  return SimplifySelectsFeedingBinaryOp(I, LHS, RHS);
+}
+
+// Try to push the binary operation \p I into select operand(s):
+//   (A ? B : C) op (A ? E : F) -> A ? (B op E) : (C op F)
+//   (A ? B : C) op Y           -> A ? (B op Y) : (C op Y)
+//   X op (D ? E : F)           -> D ? (X op E) : (X op F)
+// A new select is returned only when both arms are available; in the
+// two-select case one arm may be materialized as a new binop if the other
+// simplified and both selects have a single use. Returns nullptr otherwise.
+Value *InstCombinerImpl::SimplifySelectsFeedingBinaryOp(BinaryOperator &I,
+                                                        Value *LHS,
+                                                        Value *RHS) {
+  Value *A, *B, *C, *D, *E, *F;
+  bool LHSIsSelect = match(LHS, m_Select(m_Value(A), m_Value(B), m_Value(C)));
+  bool RHSIsSelect = match(RHS, m_Select(m_Value(D), m_Value(E), m_Value(F)));
+  if (!LHSIsSelect && !RHSIsSelect)
+    return nullptr;
+
+  // For FP operations, I's fast-math flags are used both for simplification
+  // and, via the guarded Builder, for newly created instructions.
+  FastMathFlags FMF;
+  BuilderTy::FastMathFlagGuard Guard(Builder);
+  if (isa<FPMathOperator>(&I)) {
+    FMF = I.getFastMathFlags();
+    Builder.setFastMathFlags(FMF);
+  }
+
+  Instruction::BinaryOps Opcode = I.getOpcode();
+  SimplifyQuery Q = SQ.getWithInstruction(&I);
+
+  // Cond is assigned in each branch below that can also set True/False; if no
+  // branch is taken, True and False stay null and we return before Cond is
+  // read.
+  Value *Cond, *True = nullptr, *False = nullptr;
+  if (LHSIsSelect && RHSIsSelect && A == D) {
+    // (A ? B : C) op (A ? E : F) -> A ? (B op E) : (C op F)
+    Cond = A;
+    True = simplifyBinOp(Opcode, B, E, FMF, Q);
+    False = simplifyBinOp(Opcode, C, F, FMF, Q);
+
+    // If exactly one arm simplified, build the other one explicitly, but
+    // only when both selects have a single use.
+    if (LHS->hasOneUse() && RHS->hasOneUse()) {
+      if (False && !True)
+        True = Builder.CreateBinOp(Opcode, B, E);
+      else if (True && !False)
+        False = Builder.CreateBinOp(Opcode, C, F);
+    }
+  } else if (LHSIsSelect && LHS->hasOneUse()) {
+    // (A ? B : C) op Y -> A ? (B op Y) : (C op Y)
+    Cond = A;
+    True = simplifyBinOp(Opcode, B, RHS, FMF, Q);
+    False = simplifyBinOp(Opcode, C, RHS, FMF, Q);
+  } else if (RHSIsSelect && RHS->hasOneUse()) {
+    // X op (D ? E : F) -> D ? (X op E) : (X op F)
+    Cond = D;
+    True = simplifyBinOp(Opcode, LHS, E, FMF, Q);
+    False = simplifyBinOp(Opcode, LHS, F, FMF, Q);
+  }
+
+  if (!True || !False)
+    return nullptr;
+
+  Value *SI = Builder.CreateSelect(Cond, True, False);
+  SI->takeName(&I);
+  return SI;
+}
+
+/// Freely adapt every user of I as-if I was changed to !I.
+/// WARNING: only if canFreelyInvertAllUsersOf() said this can be done.
+void InstCombinerImpl::freelyInvertAllUsersOf(Value *I) {
+  for (User *Usr : I->users()) {
+    auto *UserInst = cast<Instruction>(Usr);
+    const unsigned UserOpcode = UserInst->getOpcode();
+
+    if (UserOpcode == Instruction::Select) {
+      // Exchange the two select arms along with the profile metadata.
+      auto *Sel = cast<SelectInst>(UserInst);
+      Sel->swapValues();
+      Sel->swapProfMetadata();
+    } else if (UserOpcode == Instruction::Br) {
+      cast<BranchInst>(UserInst)->swapSuccessors(); // swaps prof metadata too
+    } else if (UserOpcode == Instruction::Xor) {
+      // Users of the xor are redirected to I itself.
+      replaceInstUsesWith(*UserInst, I);
+    } else {
+      llvm_unreachable("Got unexpected user - out of sync with "
+                       "canFreelyInvertAllUsersOf() ?");
+    }
+  }
+}
+
+/// If \p V is a negation (a 'sub' whose LHS is constant zero, as matched by
+/// m_Neg), return the value being negated. Integer constants are also treated
+/// as negated values when they can be folded: for a ConstantInt, an
+/// integer-element ConstantDataVector, a ConstantVector made only of
+/// ConstantInt/undef elements, or an integer-vector constant with a splat
+/// value, the negated constant expression is returned. Returns null in every
+/// other case.
+Value *InstCombinerImpl::dyn_castNegVal(Value *V) const {
+  Value *Negated;
+  if (match(V, m_Neg(m_Value(Negated))))
+    return Negated;
+
+  // Constants can be considered to be negated values if they can be folded.
+  if (auto *ScalarC = dyn_cast<ConstantInt>(V))
+    return ConstantExpr::getNeg(ScalarC);
+
+  if (auto *DataVec = dyn_cast<ConstantDataVector>(V)) {
+    if (DataVec->getType()->getElementType()->isIntegerTy())
+      return ConstantExpr::getNeg(DataVec);
+  }
+
+  if (auto *ConstVec = dyn_cast<ConstantVector>(V)) {
+    const unsigned NumElts = ConstVec->getNumOperands();
+    for (unsigned Idx = 0; Idx != NumElts; ++Idx) {
+      Constant *Elt = ConstVec->getAggregateElement(Idx);
+      // Give up on a missing element, or on one that is neither undef nor a
+      // ConstantInt.
+      if (!Elt || !(isa<UndefValue>(Elt) || isa<ConstantInt>(Elt)))
+        return nullptr;
+    }
+    return ConstantExpr::getNeg(ConstVec);
+  }
+
+  // Negate integer vector splats.
+  auto *AnyConst = dyn_cast<Constant>(V);
+  if (!AnyConst)
+    return nullptr;
+
+  Type *ConstTy = AnyConst->getType();
+  const bool IsIntVectorSplat = ConstTy->isVectorTy() &&
+                                ConstTy->getScalarType()->isIntegerTy() &&
+                                AnyConst->getSplatValue();
+  return IsIntVectorSplat ? ConstantExpr::getNeg(AnyConst) : nullptr;
+}
+
+/// A binop with a constant operand and a sign-extended boolean operand may be
+/// converted into a select of constants by applying the binary operation to
+/// the constant with the two possible values of the extended boolean (0 or -1).
+// Returns a new, not yet inserted select, or nullptr when operand 0 is not a
+// sext of an i1 (or vector of i1) value or operand 1 does not match
+// m_ImmConstant.
+Instruction *InstCombinerImpl::foldBinopOfSextBoolToSelect(BinaryOperator &BO) {
+  // TODO: Handle non-commutative binop (constant is operand 0).
+  // TODO: Handle zext.
+  // TODO: Peek through 'not' of cast.
+  Value *BO0 = BO.getOperand(0);
+  Value *BO1 = BO.getOperand(1);
+  Value *X;
+  Constant *C;
+  if (!match(BO0, m_SExt(m_Value(X))) || !match(BO1, m_ImmConstant(C)) ||
+      !X->getType()->isIntOrIntVectorTy(1))
+    return nullptr;
+
+  // bo (sext i1 X), C --> select X, (bo -1, C), (bo 0, C)
+  Constant *Ones = ConstantInt::getAllOnesValue(BO.getType());
+  Constant *Zero = ConstantInt::getNullValue(BO.getType());
+  Value *TVal = Builder.CreateBinOp(BO.getOpcode(), Ones, C);
+  Value *FVal = Builder.CreateBinOp(BO.getOpcode(), Zero, C);
+  return SelectInst::Create(X, TVal, FVal);
+}
+
+// Re-create the operation \p I with the select operand \p SO substituted for
+// the select itself. Handles casts, unary/binary intrinsics (asserted to be
+// constant-foldable) and binary operators with one constant operand.
+static Value *foldOperationIntoSelectOperand(Instruction &I, Value *SO,
+                                             InstCombiner::BuilderTy &Builder) {
+  if (auto *Cast = dyn_cast<CastInst>(&I))
+    return Builder.CreateCast(Cast->getOpcode(), SO, I.getType());
+
+  if (auto *II = dyn_cast<IntrinsicInst>(&I)) {
+    assert(canConstantFoldCallTo(II, cast<Function>(II->getCalledOperand())) &&
+           "Expected constant-foldable intrinsic");
+    Intrinsic::ID IID = II->getIntrinsicID();
+    if (II->arg_size() == 1)
+      return Builder.CreateUnaryIntrinsic(IID, SO);
+
+    // This works for real binary ops like min/max (where we always expect the
+    // constant operand to be canonicalized as op1) and unary ops with a bonus
+    // constant argument like ctlz/cttz.
+    // TODO: Handle non-commutative binary intrinsics as below for binops.
+    assert(II->arg_size() == 2 && "Expected binary intrinsic");
+    assert(isa<Constant>(II->getArgOperand(1)) && "Expected constant operand");
+    return Builder.CreateBinaryIntrinsic(IID, SO, II->getArgOperand(1));
+  }
+
+  assert(I.isBinaryOp() && "Unexpected opcode for select folding");
+
+  // Figure out if the constant is the left or the right argument.
+  // ConstIsRHS doubles as the operand index (0 or 1) of the constant.
+  bool ConstIsRHS = isa<Constant>(I.getOperand(1));
+  Constant *ConstOperand = cast<Constant>(I.getOperand(ConstIsRHS));
+
+  // Keep the constant on the side it occupied in I.
+  Value *Op0 = SO, *Op1 = ConstOperand;
+  if (!ConstIsRHS)
+    std::swap(Op0, Op1);
+
+  Value *NewBO = Builder.CreateBinOp(cast<BinaryOperator>(&I)->getOpcode(), Op0,
+                                     Op1, SO->getName() + ".op");
+  // CreateBinOp's result may not be an instruction; IR flags are copied from
+  // I only when it is.
+  if (auto *NewBOI = dyn_cast<Instruction>(NewBO))
+    NewBOI->copyIRFlags(&I);
+  return NewBO;
+}
+
+// Fold the operation \p Op into both arms of the select \p SI, producing
+// "select cond, Op(TV), Op(FV)". Returns the new, not yet inserted select, or
+// nullptr when the fold is rejected by one of the checks below.
+Instruction *InstCombinerImpl::FoldOpIntoSelect(Instruction &Op, SelectInst *SI,
+                                                bool FoldWithMultiUse) {
+  // Don't modify shared select instructions unless set FoldWithMultiUse
+  if (!SI->hasOneUse() && !FoldWithMultiUse)
+    return nullptr;
+
+  // At least one select arm has to be a constant.
+  Value *TV = SI->getTrueValue();
+  Value *FV = SI->getFalseValue();
+  if (!(isa<Constant>(TV) || isa<Constant>(FV)))
+    return nullptr;
+
+  // Bool selects with constant operands can be folded to logical ops.
+  if (SI->getType()->isIntOrIntVectorTy(1))
+    return nullptr;
+
+  // If it's a bitcast involving vectors, make sure it has the same number of
+  // elements on both sides.
+  if (auto *BC = dyn_cast<BitCastInst>(&Op)) {
+    VectorType *DestTy = dyn_cast<VectorType>(BC->getDestTy());
+    VectorType *SrcTy = dyn_cast<VectorType>(BC->getSrcTy());
+
+    // Verify that either both or neither are vectors.
+    if ((SrcTy == nullptr) != (DestTy == nullptr))
+      return nullptr;
+
+    // If vectors, verify that they have the same number of elements.
+    if (SrcTy && SrcTy->getElementCount() != DestTy->getElementCount())
+      return nullptr;
+  }
+
+  // Test if a CmpInst instruction is used exclusively by a select as
+  // part of a minimum or maximum operation. If so, refrain from doing
+  // any other folding. This helps out other analyses which understand
+  // non-obfuscated minimum and maximum idioms, such as ScalarEvolution
+  // and CodeGen. And in this case, at least one of the comparison
+  // operands has at least one user besides the compare (the select),
+  // which would often largely negate the benefit of folding anyway.
+  if (auto *CI = dyn_cast<CmpInst>(SI->getCondition())) {
+    if (CI->hasOneUse()) {
+      Value *Op0 = CI->getOperand(0), *Op1 = CI->getOperand(1);
+
+      // FIXME: This is a hack to avoid infinite looping with min/max patterns.
+      // We have to ensure that vector constants that only differ with
+      // undef elements are treated as equivalent.
+      auto areLooselyEqual = [](Value *A, Value *B) {
+        if (A == B)
+          return true;
+
+        // Test for vector constants.
+        Constant *ConstA, *ConstB;
+        if (!match(A, m_Constant(ConstA)) || !match(B, m_Constant(ConstB)))
+          return false;
+
+        // TODO: Deal with FP constants?
+        if (!A->getType()->isIntOrIntVectorTy() || A->getType() != B->getType())
+          return false;
+
+        // Compare for equality including undefs as equal.
+        auto *Cmp = ConstantExpr::getCompare(ICmpInst::ICMP_EQ, ConstA, ConstB);
+        const APInt *C;
+        return match(Cmp, m_APIntAllowUndef(C)) && C->isOne();
+      };
+
+      // The select arms mirror the compare operands (in either order): treat
+      // it as a min/max idiom and leave it alone.
+      if ((areLooselyEqual(TV, Op0) && areLooselyEqual(FV, Op1)) ||
+          (areLooselyEqual(FV, Op0) && areLooselyEqual(TV, Op1)))
+        return nullptr;
+    }
+  }
+
+  Value *NewTV = foldOperationIntoSelectOperand(Op, TV, Builder);
+  Value *NewFV = foldOperationIntoSelectOperand(Op, FV, Builder);
+  return SelectInst::Create(SI->getCondition(), NewTV, NewFV, "", nullptr, SI);
+}
+
+// Re-create the binary operator \p I with the phi incoming value \p InV
+// substituted for the phi, keeping I's constant operand on its original side.
+// Fast-math flags are copied from I when the result is an FP instruction.
+static Value *foldOperationIntoPhiValue(BinaryOperator *I, Value *InV,
+                                        InstCombiner::BuilderTy &Builder) {
+  // ConstIsRHS doubles as the operand index (0 or 1) of the constant.
+  bool ConstIsRHS = isa<Constant>(I->getOperand(1));
+  Constant *C = cast<Constant>(I->getOperand(ConstIsRHS));
+
+  Value *Op0 = InV, *Op1 = C;
+  if (!ConstIsRHS)
+    std::swap(Op0, Op1);
+
+  Value *RI = Builder.CreateBinOp(I->getOpcode(), Op0, Op1, "phi.bo");
+  auto *FPInst = dyn_cast<Instruction>(RI);
+  if (FPInst && isa<FPMathOperator>(FPInst))
+    FPInst->copyFastMathFlags(I);
+  return RI;
+}
+
+Instruction *InstCombinerImpl::foldOpIntoPhi(Instruction &I, PHINode *PN) {
+  unsigned NumPHIValues = PN->getNumIncomingValues();
+  if (NumPHIValues == 0)
+    return nullptr;
+
+  // We normally only transform phis with a
single use. However, if a PHI has + // multiple uses and they are all the same operation, we can fold *all* of the + // uses into the PHI. + if (!PN->hasOneUse()) { + // Walk the use list for the instruction, comparing them to I. + for (User *U : PN->users()) { + Instruction *UI = cast<Instruction>(U); + if (UI != &I && !I.isIdenticalTo(UI)) + return nullptr; + } + // Otherwise, we can replace *all* users with the new PHI we form. + } + + // Check to see if all of the operands of the PHI are simple constants + // (constantint/constantfp/undef). If there is one non-constant value, + // remember the BB it is in. If there is more than one or if *it* is a PHI, + // bail out. We don't do arbitrary constant expressions here because moving + // their computation can be expensive without a cost model. + BasicBlock *NonConstBB = nullptr; + for (unsigned i = 0; i != NumPHIValues; ++i) { + Value *InVal = PN->getIncomingValue(i); + // For non-freeze, require constant operand + // For freeze, require non-undef, non-poison operand + if (!isa<FreezeInst>(I) && match(InVal, m_ImmConstant())) + continue; + if (isa<FreezeInst>(I) && isGuaranteedNotToBeUndefOrPoison(InVal)) + continue; + + if (isa<PHINode>(InVal)) return nullptr; // Itself a phi. + if (NonConstBB) return nullptr; // More than one non-const value. + + NonConstBB = PN->getIncomingBlock(i); + + // If the InVal is an invoke at the end of the pred block, then we can't + // insert a computation after it without breaking the edge. + if (isa<InvokeInst>(InVal)) + if (cast<Instruction>(InVal)->getParent() == NonConstBB) + return nullptr; + + // If the incoming non-constant value is reachable from the phis block, + // we'll push the operation across a loop backedge. This could result in + // an infinite combine loop, and is generally non-profitable (especially + // if the operation was originally outside the loop). 
+ if (isPotentiallyReachable(PN->getParent(), NonConstBB, nullptr, &DT, LI)) + return nullptr; + } + + // If there is exactly one non-constant value, we can insert a copy of the + // operation in that block. However, if this is a critical edge, we would be + // inserting the computation on some other paths (e.g. inside a loop). Only + // do this if the pred block is unconditionally branching into the phi block. + // Also, make sure that the pred block is not dead code. + if (NonConstBB != nullptr) { + BranchInst *BI = dyn_cast<BranchInst>(NonConstBB->getTerminator()); + if (!BI || !BI->isUnconditional() || !DT.isReachableFromEntry(NonConstBB)) + return nullptr; + } + + // Okay, we can do the transformation: create the new PHI node. + PHINode *NewPN = PHINode::Create(I.getType(), PN->getNumIncomingValues()); + InsertNewInstBefore(NewPN, *PN); + NewPN->takeName(PN); + + // If we are going to have to insert a new computation, do so right before the + // predecessor's terminator. + if (NonConstBB) + Builder.SetInsertPoint(NonConstBB->getTerminator()); + + // Next, add all of the operands to the PHI. + if (SelectInst *SI = dyn_cast<SelectInst>(&I)) { + // We only currently try to fold the condition of a select when it is a phi, + // not the true/false values. + Value *TrueV = SI->getTrueValue(); + Value *FalseV = SI->getFalseValue(); + BasicBlock *PhiTransBB = PN->getParent(); + for (unsigned i = 0; i != NumPHIValues; ++i) { + BasicBlock *ThisBB = PN->getIncomingBlock(i); + Value *TrueVInPred = TrueV->DoPHITranslation(PhiTransBB, ThisBB); + Value *FalseVInPred = FalseV->DoPHITranslation(PhiTransBB, ThisBB); + Value *InV = nullptr; + // Beware of ConstantExpr: it may eventually evaluate to getNullValue, + // even if currently isNullValue gives false. + Constant *InC = dyn_cast<Constant>(PN->getIncomingValue(i)); + // For vector constants, we cannot use isNullValue to fold into + // FalseVInPred versus TrueVInPred. 
When we have individual nonzero + // elements in the vector, we will incorrectly fold InC to + // `TrueVInPred`. + if (InC && isa<ConstantInt>(InC)) + InV = InC->isNullValue() ? FalseVInPred : TrueVInPred; + else { + // Generate the select in the same block as PN's current incoming block. + // Note: ThisBB need not be the NonConstBB because vector constants + // which are constants by definition are handled here. + // FIXME: This can lead to an increase in IR generation because we might + // generate selects for vector constant phi operand, that could not be + // folded to TrueVInPred or FalseVInPred as done for ConstantInt. For + // non-vector phis, this transformation was always profitable because + // the select would be generated exactly once in the NonConstBB. + Builder.SetInsertPoint(ThisBB->getTerminator()); + InV = Builder.CreateSelect(PN->getIncomingValue(i), TrueVInPred, + FalseVInPred, "phi.sel"); + } + NewPN->addIncoming(InV, ThisBB); + } + } else if (CmpInst *CI = dyn_cast<CmpInst>(&I)) { + Constant *C = cast<Constant>(I.getOperand(1)); + for (unsigned i = 0; i != NumPHIValues; ++i) { + Value *InV = nullptr; + if (auto *InC = dyn_cast<Constant>(PN->getIncomingValue(i))) + InV = ConstantExpr::getCompare(CI->getPredicate(), InC, C); + else + InV = Builder.CreateCmp(CI->getPredicate(), PN->getIncomingValue(i), + C, "phi.cmp"); + NewPN->addIncoming(InV, PN->getIncomingBlock(i)); + } + } else if (auto *BO = dyn_cast<BinaryOperator>(&I)) { + for (unsigned i = 0; i != NumPHIValues; ++i) { + Value *InV = foldOperationIntoPhiValue(BO, PN->getIncomingValue(i), + Builder); + NewPN->addIncoming(InV, PN->getIncomingBlock(i)); + } + } else if (isa<FreezeInst>(&I)) { + for (unsigned i = 0; i != NumPHIValues; ++i) { + Value *InV; + if (NonConstBB == PN->getIncomingBlock(i)) + InV = Builder.CreateFreeze(PN->getIncomingValue(i), "phi.fr"); + else + InV = PN->getIncomingValue(i); + NewPN->addIncoming(InV, PN->getIncomingBlock(i)); + } + } else { + CastInst *CI = 
cast<CastInst>(&I); + Type *RetTy = CI->getType(); + for (unsigned i = 0; i != NumPHIValues; ++i) { + Value *InV; + if (Constant *InC = dyn_cast<Constant>(PN->getIncomingValue(i))) + InV = ConstantExpr::getCast(CI->getOpcode(), InC, RetTy); + else + InV = Builder.CreateCast(CI->getOpcode(), PN->getIncomingValue(i), + I.getType(), "phi.cast"); + NewPN->addIncoming(InV, PN->getIncomingBlock(i)); + } + } + + for (User *U : make_early_inc_range(PN->users())) { + Instruction *User = cast<Instruction>(U); + if (User == &I) continue; + replaceInstUsesWith(*User, NewPN); + eraseInstFromFunction(*User); + } + return replaceInstUsesWith(I, NewPN); +} + +Instruction *InstCombinerImpl::foldBinopWithPhiOperands(BinaryOperator &BO) { + // TODO: This should be similar to the incoming values check in foldOpIntoPhi: + // we are guarding against replicating the binop in >1 predecessor. + // This could miss matching a phi with 2 constant incoming values. + auto *Phi0 = dyn_cast<PHINode>(BO.getOperand(0)); + auto *Phi1 = dyn_cast<PHINode>(BO.getOperand(1)); + if (!Phi0 || !Phi1 || !Phi0->hasOneUse() || !Phi1->hasOneUse() || + Phi0->getNumOperands() != 2 || Phi1->getNumOperands() != 2) + return nullptr; + + // TODO: Remove the restriction for binop being in the same block as the phis. + if (BO.getParent() != Phi0->getParent() || + BO.getParent() != Phi1->getParent()) + return nullptr; + + // Match a pair of incoming constants for one of the predecessor blocks. 
+ BasicBlock *ConstBB, *OtherBB; + Constant *C0, *C1; + if (match(Phi0->getIncomingValue(0), m_ImmConstant(C0))) { + ConstBB = Phi0->getIncomingBlock(0); + OtherBB = Phi0->getIncomingBlock(1); + } else if (match(Phi0->getIncomingValue(1), m_ImmConstant(C0))) { + ConstBB = Phi0->getIncomingBlock(1); + OtherBB = Phi0->getIncomingBlock(0); + } else { + return nullptr; + } + if (!match(Phi1->getIncomingValueForBlock(ConstBB), m_ImmConstant(C1))) + return nullptr; + + // The block that we are hoisting to must reach here unconditionally. + // Otherwise, we could be speculatively executing an expensive or + // non-speculative op. + auto *PredBlockBranch = dyn_cast<BranchInst>(OtherBB->getTerminator()); + if (!PredBlockBranch || PredBlockBranch->isConditional() || + !DT.isReachableFromEntry(OtherBB)) + return nullptr; + + // TODO: This check could be tightened to only apply to binops (div/rem) that + // are not safe to speculatively execute. But that could allow hoisting + // potentially expensive instructions (fdiv for example). + for (auto BBIter = BO.getParent()->begin(); &*BBIter != &BO; ++BBIter) + if (!isGuaranteedToTransferExecutionToSuccessor(&*BBIter)) + return nullptr; + + // Make a new binop in the predecessor block with the non-constant incoming + // values. + Builder.SetInsertPoint(PredBlockBranch); + Value *NewBO = Builder.CreateBinOp(BO.getOpcode(), + Phi0->getIncomingValueForBlock(OtherBB), + Phi1->getIncomingValueForBlock(OtherBB)); + if (auto *NotFoldedNewBO = dyn_cast<BinaryOperator>(NewBO)) + NotFoldedNewBO->copyIRFlags(&BO); + + // Fold constants for the predecessor block with constant incoming values. + Constant *NewC = ConstantExpr::get(BO.getOpcode(), C0, C1); + + // Replace the binop with a phi of the new values. The old phis are dead. 
+ PHINode *NewPhi = PHINode::Create(BO.getType(), 2); + NewPhi->addIncoming(NewBO, OtherBB); + NewPhi->addIncoming(NewC, ConstBB); + return NewPhi; +} + +Instruction *InstCombinerImpl::foldBinOpIntoSelectOrPhi(BinaryOperator &I) { + if (!isa<Constant>(I.getOperand(1))) + return nullptr; + + if (auto *Sel = dyn_cast<SelectInst>(I.getOperand(0))) { + if (Instruction *NewSel = FoldOpIntoSelect(I, Sel)) + return NewSel; + } else if (auto *PN = dyn_cast<PHINode>(I.getOperand(0))) { + if (Instruction *NewPhi = foldOpIntoPhi(I, PN)) + return NewPhi; + } + return nullptr; +} + +/// Given a pointer type and a constant offset, determine whether or not there +/// is a sequence of GEP indices into the pointed type that will land us at the +/// specified offset. If so, fill them into NewIndices and return the resultant +/// element type, otherwise return null. +static Type *findElementAtOffset(PointerType *PtrTy, int64_t IntOffset, + SmallVectorImpl<Value *> &NewIndices, + const DataLayout &DL) { + // Only used by visitGEPOfBitcast(), which is skipped for opaque pointers. + Type *Ty = PtrTy->getNonOpaquePointerElementType(); + if (!Ty->isSized()) + return nullptr; + + APInt Offset(DL.getIndexTypeSizeInBits(PtrTy), IntOffset); + SmallVector<APInt> Indices = DL.getGEPIndicesForOffset(Ty, Offset); + if (!Offset.isZero()) + return nullptr; + + for (const APInt &Index : Indices) + NewIndices.push_back(ConstantInt::get(PtrTy->getContext(), Index)); + return Ty; +} + +static bool shouldMergeGEPs(GEPOperator &GEP, GEPOperator &Src) { + // If this GEP has only 0 indices, it is the same pointer as + // Src. If Src is not a trivial GEP too, don't combine + // the indices. + if (GEP.hasAllZeroIndices() && !Src.hasAllZeroIndices() && + !Src.hasOneUse()) + return false; + return true; +} + +/// Return a value X such that Val = X * Scale, or null if none. +/// If the multiplication is known not to overflow, then NoSignedWrap is set. 
+Value *InstCombinerImpl::Descale(Value *Val, APInt Scale, bool &NoSignedWrap) { + assert(isa<IntegerType>(Val->getType()) && "Can only descale integers!"); + assert(cast<IntegerType>(Val->getType())->getBitWidth() == + Scale.getBitWidth() && "Scale not compatible with value!"); + + // If Val is zero or Scale is one then Val = Val * Scale. + if (match(Val, m_Zero()) || Scale == 1) { + NoSignedWrap = true; + return Val; + } + + // If Scale is zero then it does not divide Val. + if (Scale.isMinValue()) + return nullptr; + + // Look through chains of multiplications, searching for a constant that is + // divisible by Scale. For example, descaling X*(Y*(Z*4)) by a factor of 4 + // will find the constant factor 4 and produce X*(Y*Z). Descaling X*(Y*8) by + // a factor of 4 will produce X*(Y*2). The principle of operation is to bore + // down from Val: + // + // Val = M1 * X || Analysis starts here and works down + // M1 = M2 * Y || Doesn't descend into terms with more + // M2 = Z * 4 \/ than one use + // + // Then to modify a term at the bottom: + // + // Val = M1 * X + // M1 = Z * Y || Replaced M2 with Z + // + // Then to work back up correcting nsw flags. + + // Op - the term we are currently analyzing. Starts at Val then drills down. + // Replaced with its descaled value before exiting from the drill down loop. + Value *Op = Val; + + // Parent - initially null, but after drilling down notes where Op came from. + // In the example above, Parent is (Val, 0) when Op is M1, because M1 is the + // 0'th operand of Val. + std::pair<Instruction *, unsigned> Parent; + + // Set if the transform requires a descaling at deeper levels that doesn't + // overflow. + bool RequireNoSignedWrap = false; + + // Log base 2 of the scale. Negative if not a power of 2. 
+ int32_t logScale = Scale.exactLogBase2(); + + for (;; Op = Parent.first->getOperand(Parent.second)) { // Drill down + if (ConstantInt *CI = dyn_cast<ConstantInt>(Op)) { + // If Op is a constant divisible by Scale then descale to the quotient. + APInt Quotient(Scale), Remainder(Scale); // Init ensures right bitwidth. + APInt::sdivrem(CI->getValue(), Scale, Quotient, Remainder); + if (!Remainder.isMinValue()) + // Not divisible by Scale. + return nullptr; + // Replace with the quotient in the parent. + Op = ConstantInt::get(CI->getType(), Quotient); + NoSignedWrap = true; + break; + } + + if (BinaryOperator *BO = dyn_cast<BinaryOperator>(Op)) { + if (BO->getOpcode() == Instruction::Mul) { + // Multiplication. + NoSignedWrap = BO->hasNoSignedWrap(); + if (RequireNoSignedWrap && !NoSignedWrap) + return nullptr; + + // There are three cases for multiplication: multiplication by exactly + // the scale, multiplication by a constant different to the scale, and + // multiplication by something else. + Value *LHS = BO->getOperand(0); + Value *RHS = BO->getOperand(1); + + if (ConstantInt *CI = dyn_cast<ConstantInt>(RHS)) { + // Multiplication by a constant. + if (CI->getValue() == Scale) { + // Multiplication by exactly the scale, replace the multiplication + // by its left-hand side in the parent. + Op = LHS; + break; + } + + // Otherwise drill down into the constant. + if (!Op->hasOneUse()) + return nullptr; + + Parent = std::make_pair(BO, 1); + continue; + } + + // Multiplication by something else. Drill down into the left-hand side + // since that's where the reassociate pass puts the good stuff. + if (!Op->hasOneUse()) + return nullptr; + + Parent = std::make_pair(BO, 0); + continue; + } + + if (logScale > 0 && BO->getOpcode() == Instruction::Shl && + isa<ConstantInt>(BO->getOperand(1))) { + // Multiplication by a power of 2. 
+ NoSignedWrap = BO->hasNoSignedWrap(); + if (RequireNoSignedWrap && !NoSignedWrap) + return nullptr; + + Value *LHS = BO->getOperand(0); + int32_t Amt = cast<ConstantInt>(BO->getOperand(1))-> + getLimitedValue(Scale.getBitWidth()); + // Op = LHS << Amt. + + if (Amt == logScale) { + // Multiplication by exactly the scale, replace the multiplication + // by its left-hand side in the parent. + Op = LHS; + break; + } + if (Amt < logScale || !Op->hasOneUse()) + return nullptr; + + // Multiplication by more than the scale. Reduce the multiplying amount + // by the scale in the parent. + Parent = std::make_pair(BO, 1); + Op = ConstantInt::get(BO->getType(), Amt - logScale); + break; + } + } + + if (!Op->hasOneUse()) + return nullptr; + + if (CastInst *Cast = dyn_cast<CastInst>(Op)) { + if (Cast->getOpcode() == Instruction::SExt) { + // Op is sign-extended from a smaller type, descale in the smaller type. + unsigned SmallSize = Cast->getSrcTy()->getPrimitiveSizeInBits(); + APInt SmallScale = Scale.trunc(SmallSize); + // Suppose Op = sext X, and we descale X as Y * SmallScale. We want to + // descale Op as (sext Y) * Scale. In order to have + // sext (Y * SmallScale) = (sext Y) * Scale + // some conditions need to hold however: SmallScale must sign-extend to + // Scale and the multiplication Y * SmallScale should not overflow. + if (SmallScale.sext(Scale.getBitWidth()) != Scale) + // SmallScale does not sign-extend to Scale. + return nullptr; + assert(SmallScale.exactLogBase2() == logScale); + // Require that Y * SmallScale must not overflow. + RequireNoSignedWrap = true; + + // Drill down through the cast. + Parent = std::make_pair(Cast, 0); + Scale = SmallScale; + continue; + } + + if (Cast->getOpcode() == Instruction::Trunc) { + // Op is truncated from a larger type, descale in the larger type. + // Suppose Op = trunc X, and we descale X as Y * sext Scale. Then + // trunc (Y * sext Scale) = (trunc Y) * Scale + // always holds. 
However (trunc Y) * Scale may overflow even if + // trunc (Y * sext Scale) does not, so nsw flags need to be cleared + // from this point up in the expression (see later). + if (RequireNoSignedWrap) + return nullptr; + + // Drill down through the cast. + unsigned LargeSize = Cast->getSrcTy()->getPrimitiveSizeInBits(); + Parent = std::make_pair(Cast, 0); + Scale = Scale.sext(LargeSize); + if (logScale + 1 == (int32_t)Cast->getType()->getPrimitiveSizeInBits()) + logScale = -1; + assert(Scale.exactLogBase2() == logScale); + continue; + } + } + + // Unsupported expression, bail out. + return nullptr; + } + + // If Op is zero then Val = Op * Scale. + if (match(Op, m_Zero())) { + NoSignedWrap = true; + return Op; + } + + // We know that we can successfully descale, so from here on we can safely + // modify the IR. Op holds the descaled version of the deepest term in the + // expression. NoSignedWrap is 'true' if multiplying Op by Scale is known + // not to overflow. + + if (!Parent.first) + // The expression only had one term. + return Op; + + // Rewrite the parent using the descaled version of its operand. + assert(Parent.first->hasOneUse() && "Drilled down when more than one use!"); + assert(Op != Parent.first->getOperand(Parent.second) && + "Descaling was a no-op?"); + replaceOperand(*Parent.first, Parent.second, Op); + Worklist.push(Parent.first); + + // Now work back up the expression correcting nsw flags. The logic is based + // on the following observation: if X * Y is known not to overflow as a signed + // multiplication, and Y is replaced by a value Z with smaller absolute value, + // then X * Z will not overflow as a signed multiplication either. As we work + // our way up, having NoSignedWrap 'true' means that the descaled value at the + // current level has strictly smaller absolute value than the original. 
+ Instruction *Ancestor = Parent.first; + do { + if (BinaryOperator *BO = dyn_cast<BinaryOperator>(Ancestor)) { + // If the multiplication wasn't nsw then we can't say anything about the + // value of the descaled multiplication, and we have to clear nsw flags + // from this point on up. + bool OpNoSignedWrap = BO->hasNoSignedWrap(); + NoSignedWrap &= OpNoSignedWrap; + if (NoSignedWrap != OpNoSignedWrap) { + BO->setHasNoSignedWrap(NoSignedWrap); + Worklist.push(Ancestor); + } + } else if (Ancestor->getOpcode() == Instruction::Trunc) { + // The fact that the descaled input to the trunc has smaller absolute + // value than the original input doesn't tell us anything useful about + // the absolute values of the truncations. + NoSignedWrap = false; + } + assert((Ancestor->getOpcode() != Instruction::SExt || NoSignedWrap) && + "Failed to keep proper track of nsw flags while drilling down?"); + + if (Ancestor == Val) + // Got to the top, all done! + return Val; + + // Move up one level in the expression. + assert(Ancestor->hasOneUse() && "Drilled down when more than one use!"); + Ancestor = Ancestor->user_back(); + } while (true); +} + +Instruction *InstCombinerImpl::foldVectorBinop(BinaryOperator &Inst) { + if (!isa<VectorType>(Inst.getType())) + return nullptr; + + BinaryOperator::BinaryOps Opcode = Inst.getOpcode(); + Value *LHS = Inst.getOperand(0), *RHS = Inst.getOperand(1); + assert(cast<VectorType>(LHS->getType())->getElementCount() == + cast<VectorType>(Inst.getType())->getElementCount()); + assert(cast<VectorType>(RHS->getType())->getElementCount() == + cast<VectorType>(Inst.getType())->getElementCount()); + + // If both operands of the binop are vector concatenations, then perform the + // narrow binop on each pair of the source operands followed by concatenation + // of the results. 
+ Value *L0, *L1, *R0, *R1; + ArrayRef<int> Mask; + if (match(LHS, m_Shuffle(m_Value(L0), m_Value(L1), m_Mask(Mask))) && + match(RHS, m_Shuffle(m_Value(R0), m_Value(R1), m_SpecificMask(Mask))) && + LHS->hasOneUse() && RHS->hasOneUse() && + cast<ShuffleVectorInst>(LHS)->isConcat() && + cast<ShuffleVectorInst>(RHS)->isConcat()) { + // This transform does not have the speculative execution constraint as + // below because the shuffle is a concatenation. The new binops are + // operating on exactly the same elements as the existing binop. + // TODO: We could ease the mask requirement to allow different undef lanes, + // but that requires an analysis of the binop-with-undef output value. + Value *NewBO0 = Builder.CreateBinOp(Opcode, L0, R0); + if (auto *BO = dyn_cast<BinaryOperator>(NewBO0)) + BO->copyIRFlags(&Inst); + Value *NewBO1 = Builder.CreateBinOp(Opcode, L1, R1); + if (auto *BO = dyn_cast<BinaryOperator>(NewBO1)) + BO->copyIRFlags(&Inst); + return new ShuffleVectorInst(NewBO0, NewBO1, Mask); + } + + // It may not be safe to reorder shuffles and things like div, urem, etc. + // because we may trap when executing those ops on unknown vector elements. + // See PR20059. + if (!isSafeToSpeculativelyExecute(&Inst)) + return nullptr; + + auto createBinOpShuffle = [&](Value *X, Value *Y, ArrayRef<int> M) { + Value *XY = Builder.CreateBinOp(Opcode, X, Y); + if (auto *BO = dyn_cast<BinaryOperator>(XY)) + BO->copyIRFlags(&Inst); + return new ShuffleVectorInst(XY, M); + }; + + // If both arguments of the binary operation are shuffles that use the same + // mask and shuffle within a single vector, move the shuffle after the binop. 
+ Value *V1, *V2; + if (match(LHS, m_Shuffle(m_Value(V1), m_Undef(), m_Mask(Mask))) && + match(RHS, m_Shuffle(m_Value(V2), m_Undef(), m_SpecificMask(Mask))) && + V1->getType() == V2->getType() && + (LHS->hasOneUse() || RHS->hasOneUse() || LHS == RHS)) { + // Op(shuffle(V1, Mask), shuffle(V2, Mask)) -> shuffle(Op(V1, V2), Mask) + return createBinOpShuffle(V1, V2, Mask); + } + + // If both arguments of a commutative binop are select-shuffles that use the + // same mask with commuted operands, the shuffles are unnecessary. + if (Inst.isCommutative() && + match(LHS, m_Shuffle(m_Value(V1), m_Value(V2), m_Mask(Mask))) && + match(RHS, + m_Shuffle(m_Specific(V2), m_Specific(V1), m_SpecificMask(Mask)))) { + auto *LShuf = cast<ShuffleVectorInst>(LHS); + auto *RShuf = cast<ShuffleVectorInst>(RHS); + // TODO: Allow shuffles that contain undefs in the mask? + // That is legal, but it reduces undef knowledge. + // TODO: Allow arbitrary shuffles by shuffling after binop? + // That might be legal, but we have to deal with poison. + if (LShuf->isSelect() && + !is_contained(LShuf->getShuffleMask(), UndefMaskElem) && + RShuf->isSelect() && + !is_contained(RShuf->getShuffleMask(), UndefMaskElem)) { + // Example: + // LHS = shuffle V1, V2, <0, 5, 6, 3> + // RHS = shuffle V2, V1, <0, 5, 6, 3> + // LHS + RHS --> (V10+V20, V21+V11, V22+V12, V13+V23) --> V1 + V2 + Instruction *NewBO = BinaryOperator::Create(Opcode, V1, V2); + NewBO->copyIRFlags(&Inst); + return NewBO; + } + } + + // If one argument is a shuffle within one vector and the other is a constant, + // try moving the shuffle after the binary operation. This canonicalization + // intends to move shuffles closer to other shuffles and binops closer to + // other binops, so they can be folded. It may also enable demanded elements + // transforms. 
+ Constant *C; + auto *InstVTy = dyn_cast<FixedVectorType>(Inst.getType()); + if (InstVTy && + match(&Inst, + m_c_BinOp(m_OneUse(m_Shuffle(m_Value(V1), m_Undef(), m_Mask(Mask))), + m_ImmConstant(C))) && + cast<FixedVectorType>(V1->getType())->getNumElements() <= + InstVTy->getNumElements()) { + assert(InstVTy->getScalarType() == V1->getType()->getScalarType() && + "Shuffle should not change scalar type"); + + // Find constant NewC that has property: + // shuffle(NewC, ShMask) = C + // If such constant does not exist (example: ShMask=<0,0> and C=<1,2>) + // reorder is not possible. A 1-to-1 mapping is not required. Example: + // ShMask = <1,1,2,2> and C = <5,5,6,6> --> NewC = <undef,5,6,undef> + bool ConstOp1 = isa<Constant>(RHS); + ArrayRef<int> ShMask = Mask; + unsigned SrcVecNumElts = + cast<FixedVectorType>(V1->getType())->getNumElements(); + UndefValue *UndefScalar = UndefValue::get(C->getType()->getScalarType()); + SmallVector<Constant *, 16> NewVecC(SrcVecNumElts, UndefScalar); + bool MayChange = true; + unsigned NumElts = InstVTy->getNumElements(); + for (unsigned I = 0; I < NumElts; ++I) { + Constant *CElt = C->getAggregateElement(I); + if (ShMask[I] >= 0) { + assert(ShMask[I] < (int)NumElts && "Not expecting narrowing shuffle"); + Constant *NewCElt = NewVecC[ShMask[I]]; + // Bail out if: + // 1. The constant vector contains a constant expression. + // 2. The shuffle needs an element of the constant vector that can't + // be mapped to a new constant vector. + // 3. This is a widening shuffle that copies elements of V1 into the + // extended elements (extending with undef is allowed). + if (!CElt || (!isa<UndefValue>(NewCElt) && NewCElt != CElt) || + I >= SrcVecNumElts) { + MayChange = false; + break; + } + NewVecC[ShMask[I]] = CElt; + } + // If this is a widening shuffle, we must be able to extend with undef + // elements. If the original binop does not produce an undef in the high + // lanes, then this transform is not safe. 
+ // Similarly for undef lanes due to the shuffle mask, we can only + // transform binops that preserve undef. + // TODO: We could shuffle those non-undef constant values into the + // result by using a constant vector (rather than an undef vector) + // as operand 1 of the new binop, but that might be too aggressive + // for target-independent shuffle creation. + if (I >= SrcVecNumElts || ShMask[I] < 0) { + Constant *MaybeUndef = + ConstOp1 ? ConstantExpr::get(Opcode, UndefScalar, CElt) + : ConstantExpr::get(Opcode, CElt, UndefScalar); + if (!match(MaybeUndef, m_Undef())) { + MayChange = false; + break; + } + } + } + if (MayChange) { + Constant *NewC = ConstantVector::get(NewVecC); + // It may not be safe to execute a binop on a vector with undef elements + // because the entire instruction can be folded to undef or create poison + // that did not exist in the original code. + if (Inst.isIntDivRem() || (Inst.isShift() && ConstOp1)) + NewC = getSafeVectorConstantForBinop(Opcode, NewC, ConstOp1); + + // Op(shuffle(V1, Mask), C) -> shuffle(Op(V1, NewC), Mask) + // Op(C, shuffle(V1, Mask)) -> shuffle(Op(NewC, V1), Mask) + Value *NewLHS = ConstOp1 ? V1 : NewC; + Value *NewRHS = ConstOp1 ? NewC : V1; + return createBinOpShuffle(NewLHS, NewRHS, Mask); + } + } + + // Try to reassociate to sink a splat shuffle after a binary operation. + if (Inst.isAssociative() && Inst.isCommutative()) { + // Canonicalize shuffle operand as LHS. + if (isa<ShuffleVectorInst>(RHS)) + std::swap(LHS, RHS); + + Value *X; + ArrayRef<int> MaskC; + int SplatIndex; + Value *Y, *OtherOp; + if (!match(LHS, + m_OneUse(m_Shuffle(m_Value(X), m_Undef(), m_Mask(MaskC)))) || + !match(MaskC, m_SplatOrUndefMask(SplatIndex)) || + X->getType() != Inst.getType() || + !match(RHS, m_OneUse(m_BinOp(Opcode, m_Value(Y), m_Value(OtherOp))))) + return nullptr; + + // FIXME: This may not be safe if the analysis allows undef elements. 
By + // moving 'Y' before the splat shuffle, we are implicitly assuming + // that it is not undef/poison at the splat index. + if (isSplatValue(OtherOp, SplatIndex)) { + std::swap(Y, OtherOp); + } else if (!isSplatValue(Y, SplatIndex)) { + return nullptr; + } + + // X and Y are splatted values, so perform the binary operation on those + // values followed by a splat followed by the 2nd binary operation: + // bo (splat X), (bo Y, OtherOp) --> bo (splat (bo X, Y)), OtherOp + Value *NewBO = Builder.CreateBinOp(Opcode, X, Y); + SmallVector<int, 8> NewMask(MaskC.size(), SplatIndex); + Value *NewSplat = Builder.CreateShuffleVector(NewBO, NewMask); + Instruction *R = BinaryOperator::Create(Opcode, NewSplat, OtherOp); + + // Intersect FMF on both new binops. Other (poison-generating) flags are + // dropped to be safe. + if (isa<FPMathOperator>(R)) { + R->copyFastMathFlags(&Inst); + R->andIRFlags(RHS); + } + if (auto *NewInstBO = dyn_cast<BinaryOperator>(NewBO)) + NewInstBO->copyIRFlags(R); + return R; + } + + return nullptr; +} + +/// Try to narrow the width of a binop if at least 1 operand is an extend of +/// of a value. This requires a potentially expensive known bits check to make +/// sure the narrow op does not overflow. +Instruction *InstCombinerImpl::narrowMathIfNoOverflow(BinaryOperator &BO) { + // We need at least one extended operand. + Value *Op0 = BO.getOperand(0), *Op1 = BO.getOperand(1); + + // If this is a sub, we swap the operands since we always want an extension + // on the RHS. The LHS can be an extension or a constant. + if (BO.getOpcode() == Instruction::Sub) + std::swap(Op0, Op1); + + Value *X; + bool IsSext = match(Op0, m_SExt(m_Value(X))); + if (!IsSext && !match(Op0, m_ZExt(m_Value(X)))) + return nullptr; + + // If both operands are the same extension from the same source type and we + // can eliminate at least one (hasOneUse), this might work. + CastInst::CastOps CastOpc = IsSext ? 
Instruction::SExt : Instruction::ZExt; + Value *Y; + if (!(match(Op1, m_ZExtOrSExt(m_Value(Y))) && X->getType() == Y->getType() && + cast<Operator>(Op1)->getOpcode() == CastOpc && + (Op0->hasOneUse() || Op1->hasOneUse()))) { + // If that did not match, see if we have a suitable constant operand. + // Truncating and extending must produce the same constant. + Constant *WideC; + if (!Op0->hasOneUse() || !match(Op1, m_Constant(WideC))) + return nullptr; + Constant *NarrowC = ConstantExpr::getTrunc(WideC, X->getType()); + if (ConstantExpr::getCast(CastOpc, NarrowC, BO.getType()) != WideC) + return nullptr; + Y = NarrowC; + } + + // Swap back now that we found our operands. + if (BO.getOpcode() == Instruction::Sub) + std::swap(X, Y); + + // Both operands have narrow versions. Last step: the math must not overflow + // in the narrow width. + if (!willNotOverflow(BO.getOpcode(), X, Y, BO, IsSext)) + return nullptr; + + // bo (ext X), (ext Y) --> ext (bo X, Y) + // bo (ext X), C --> ext (bo X, C') + Value *NarrowBO = Builder.CreateBinOp(BO.getOpcode(), X, Y, "narrow"); + if (auto *NewBinOp = dyn_cast<BinaryOperator>(NarrowBO)) { + if (IsSext) + NewBinOp->setHasNoSignedWrap(); + else + NewBinOp->setHasNoUnsignedWrap(); + } + return CastInst::Create(CastOpc, NarrowBO, BO.getType()); +} + +static bool isMergedGEPInBounds(GEPOperator &GEP1, GEPOperator &GEP2) { + // At least one GEP must be inbounds. + if (!GEP1.isInBounds() && !GEP2.isInBounds()) + return false; + + return (GEP1.isInBounds() || GEP1.hasAllZeroIndices()) && + (GEP2.isInBounds() || GEP2.hasAllZeroIndices()); +} + +/// Thread a GEP operation with constant indices through the constant true/false +/// arms of a select. 
+static Instruction *foldSelectGEP(GetElementPtrInst &GEP, + InstCombiner::BuilderTy &Builder) { + if (!GEP.hasAllConstantIndices()) + return nullptr; + + Instruction *Sel; + Value *Cond; + Constant *TrueC, *FalseC; + if (!match(GEP.getPointerOperand(), m_Instruction(Sel)) || + !match(Sel, + m_Select(m_Value(Cond), m_Constant(TrueC), m_Constant(FalseC)))) + return nullptr; + + // gep (select Cond, TrueC, FalseC), IndexC --> select Cond, TrueC', FalseC' + // Propagate 'inbounds' and metadata from existing instructions. + // Note: using IRBuilder to create the constants for efficiency. + SmallVector<Value *, 4> IndexC(GEP.indices()); + bool IsInBounds = GEP.isInBounds(); + Type *Ty = GEP.getSourceElementType(); + Value *NewTrueC = Builder.CreateGEP(Ty, TrueC, IndexC, "", IsInBounds); + Value *NewFalseC = Builder.CreateGEP(Ty, FalseC, IndexC, "", IsInBounds); + return SelectInst::Create(Cond, NewTrueC, NewFalseC, "", nullptr, Sel); +} + +Instruction *InstCombinerImpl::visitGEPOfGEP(GetElementPtrInst &GEP, + GEPOperator *Src) { + // Combine Indices - If the source pointer to this getelementptr instruction + // is a getelementptr instruction with matching element type, combine the + // indices of the two getelementptr instructions into a single instruction. + if (!shouldMergeGEPs(*cast<GEPOperator>(&GEP), *Src)) + return nullptr; + + if (Src->getResultElementType() == GEP.getSourceElementType() && + Src->getNumOperands() == 2 && GEP.getNumOperands() == 2 && + Src->hasOneUse()) { + Value *GO1 = GEP.getOperand(1); + Value *SO1 = Src->getOperand(1); + + if (LI) { + // Try to reassociate loop invariant GEP chains to enable LICM. + if (Loop *L = LI->getLoopFor(GEP.getParent())) { + // Reassociate the two GEPs if SO1 is variant in the loop and GO1 is + // invariant: this breaks the dependence between GEPs and allows LICM + // to hoist the invariant part out of the loop. 
+ if (L->isLoopInvariant(GO1) && !L->isLoopInvariant(SO1)) { + // The swapped GEPs are inbounds if both original GEPs are inbounds + // and the sign of the offsets is the same. For simplicity, only + // handle both offsets being non-negative. + bool IsInBounds = Src->isInBounds() && GEP.isInBounds() && + isKnownNonNegative(SO1, DL, 0, &AC, &GEP, &DT) && + isKnownNonNegative(GO1, DL, 0, &AC, &GEP, &DT); + // Put NewSrc at same location as %src. + Builder.SetInsertPoint(cast<Instruction>(Src)); + Value *NewSrc = Builder.CreateGEP(GEP.getSourceElementType(), + Src->getPointerOperand(), GO1, + Src->getName(), IsInBounds); + GetElementPtrInst *NewGEP = GetElementPtrInst::Create( + GEP.getSourceElementType(), NewSrc, {SO1}); + NewGEP->setIsInBounds(IsInBounds); + return NewGEP; + } + } + } + } + + // Note that if our source is a gep chain itself then we wait for that + // chain to be resolved before we perform this transformation. This + // avoids us creating a TON of code in some cases. + if (auto *SrcGEP = dyn_cast<GEPOperator>(Src->getOperand(0))) + if (SrcGEP->getNumOperands() == 2 && shouldMergeGEPs(*Src, *SrcGEP)) + return nullptr; // Wait until our source is folded to completion. + + // For constant GEPs, use a more general offset-based folding approach. + // Only do this for opaque pointers, as the result element type may change. + Type *PtrTy = Src->getType()->getScalarType(); + if (PtrTy->isOpaquePointerTy() && GEP.hasAllConstantIndices() && + (Src->hasOneUse() || Src->hasAllConstantIndices())) { + // Split Src into a variable part and a constant suffix. + gep_type_iterator GTI = gep_type_begin(*Src); + Type *BaseType = GTI.getIndexedType(); + bool IsFirstType = true; + unsigned NumVarIndices = 0; + for (auto Pair : enumerate(Src->indices())) { + if (!isa<ConstantInt>(Pair.value())) { + BaseType = GTI.getIndexedType(); + IsFirstType = false; + NumVarIndices = Pair.index() + 1; + } + ++GTI; + } + + // Determine the offset for the constant suffix of Src. 
+ APInt Offset(DL.getIndexTypeSizeInBits(PtrTy), 0); + if (NumVarIndices != Src->getNumIndices()) { + // FIXME: getIndexedOffsetInType() does not handled scalable vectors. + if (isa<ScalableVectorType>(BaseType)) + return nullptr; + + SmallVector<Value *> ConstantIndices; + if (!IsFirstType) + ConstantIndices.push_back( + Constant::getNullValue(Type::getInt32Ty(GEP.getContext()))); + append_range(ConstantIndices, drop_begin(Src->indices(), NumVarIndices)); + Offset += DL.getIndexedOffsetInType(BaseType, ConstantIndices); + } + + // Add the offset for GEP (which is fully constant). + if (!GEP.accumulateConstantOffset(DL, Offset)) + return nullptr; + + APInt OffsetOld = Offset; + // Convert the total offset back into indices. + SmallVector<APInt> ConstIndices = + DL.getGEPIndicesForOffset(BaseType, Offset); + if (!Offset.isZero() || (!IsFirstType && !ConstIndices[0].isZero())) { + // If both GEP are constant-indexed, and cannot be merged in either way, + // convert them to a GEP of i8. + if (Src->hasAllConstantIndices()) + return isMergedGEPInBounds(*Src, *cast<GEPOperator>(&GEP)) + ? GetElementPtrInst::CreateInBounds( + Builder.getInt8Ty(), Src->getOperand(0), + Builder.getInt(OffsetOld), GEP.getName()) + : GetElementPtrInst::Create( + Builder.getInt8Ty(), Src->getOperand(0), + Builder.getInt(OffsetOld), GEP.getName()); + return nullptr; + } + + bool IsInBounds = isMergedGEPInBounds(*Src, *cast<GEPOperator>(&GEP)); + SmallVector<Value *> Indices; + append_range(Indices, drop_end(Src->indices(), + Src->getNumIndices() - NumVarIndices)); + for (const APInt &Idx : drop_begin(ConstIndices, !IsFirstType)) { + Indices.push_back(ConstantInt::get(GEP.getContext(), Idx)); + // Even if the total offset is inbounds, we may end up representing it + // by first performing a larger negative offset, and then a smaller + // positive one. The large negative offset might go out of bounds. Only + // preserve inbounds if all signs are the same. 
+ IsInBounds &= Idx.isNonNegative() == ConstIndices[0].isNonNegative(); + } + + return IsInBounds + ? GetElementPtrInst::CreateInBounds(Src->getSourceElementType(), + Src->getOperand(0), Indices, + GEP.getName()) + : GetElementPtrInst::Create(Src->getSourceElementType(), + Src->getOperand(0), Indices, + GEP.getName()); + } + + if (Src->getResultElementType() != GEP.getSourceElementType()) + return nullptr; + + SmallVector<Value*, 8> Indices; + + // Find out whether the last index in the source GEP is a sequential idx. + bool EndsWithSequential = false; + for (gep_type_iterator I = gep_type_begin(*Src), E = gep_type_end(*Src); + I != E; ++I) + EndsWithSequential = I.isSequential(); + + // Can we combine the two pointer arithmetics offsets? + if (EndsWithSequential) { + // Replace: gep (gep %P, long B), long A, ... + // With: T = long A+B; gep %P, T, ... + Value *SO1 = Src->getOperand(Src->getNumOperands()-1); + Value *GO1 = GEP.getOperand(1); + + // If they aren't the same type, then the input hasn't been processed + // by the loop above yet (which canonicalizes sequential index types to + // intptr_t). Just avoid transforming this until the input has been + // normalized. + if (SO1->getType() != GO1->getType()) + return nullptr; + + Value *Sum = + simplifyAddInst(GO1, SO1, false, false, SQ.getWithInstruction(&GEP)); + // Only do the combine when we are sure the cost after the + // merge is never more than that before the merge. + if (Sum == nullptr) + return nullptr; + + // Update the GEP in place if possible. 
+ if (Src->getNumOperands() == 2) { + GEP.setIsInBounds(isMergedGEPInBounds(*Src, *cast<GEPOperator>(&GEP))); + replaceOperand(GEP, 0, Src->getOperand(0)); + replaceOperand(GEP, 1, Sum); + return &GEP; + } + Indices.append(Src->op_begin()+1, Src->op_end()-1); + Indices.push_back(Sum); + Indices.append(GEP.op_begin()+2, GEP.op_end()); + } else if (isa<Constant>(*GEP.idx_begin()) && + cast<Constant>(*GEP.idx_begin())->isNullValue() && + Src->getNumOperands() != 1) { + // Otherwise we can do the fold if the first index of the GEP is a zero + Indices.append(Src->op_begin()+1, Src->op_end()); + Indices.append(GEP.idx_begin()+1, GEP.idx_end()); + } + + if (!Indices.empty()) + return isMergedGEPInBounds(*Src, *cast<GEPOperator>(&GEP)) + ? GetElementPtrInst::CreateInBounds( + Src->getSourceElementType(), Src->getOperand(0), Indices, + GEP.getName()) + : GetElementPtrInst::Create(Src->getSourceElementType(), + Src->getOperand(0), Indices, + GEP.getName()); + + return nullptr; +} + +// Note that we may have also stripped an address space cast in between. +Instruction *InstCombinerImpl::visitGEPOfBitcast(BitCastInst *BCI, + GetElementPtrInst &GEP) { + // With opaque pointers, there is no pointer element type we can use to + // adjust the GEP type. 
+ PointerType *SrcType = cast<PointerType>(BCI->getSrcTy()); + if (SrcType->isOpaque()) + return nullptr; + + Type *GEPEltType = GEP.getSourceElementType(); + Type *SrcEltType = SrcType->getNonOpaquePointerElementType(); + Value *SrcOp = BCI->getOperand(0); + + // GEP directly using the source operand if this GEP is accessing an element + // of a bitcasted pointer to vector or array of the same dimensions: + // gep (bitcast <c x ty>* X to [c x ty]*), Y, Z --> gep X, Y, Z + // gep (bitcast [c x ty]* X to <c x ty>*), Y, Z --> gep X, Y, Z + auto areMatchingArrayAndVecTypes = [](Type *ArrTy, Type *VecTy, + const DataLayout &DL) { + auto *VecVTy = cast<FixedVectorType>(VecTy); + return ArrTy->getArrayElementType() == VecVTy->getElementType() && + ArrTy->getArrayNumElements() == VecVTy->getNumElements() && + DL.getTypeAllocSize(ArrTy) == DL.getTypeAllocSize(VecTy); + }; + if (GEP.getNumOperands() == 3 && + ((GEPEltType->isArrayTy() && isa<FixedVectorType>(SrcEltType) && + areMatchingArrayAndVecTypes(GEPEltType, SrcEltType, DL)) || + (isa<FixedVectorType>(GEPEltType) && SrcEltType->isArrayTy() && + areMatchingArrayAndVecTypes(SrcEltType, GEPEltType, DL)))) { + + // Create a new GEP here, as using `setOperand()` followed by + // `setSourceElementType()` won't actually update the type of the + // existing GEP Value. Causing issues if this Value is accessed when + // constructing an AddrSpaceCastInst + SmallVector<Value *, 8> Indices(GEP.indices()); + Value *NGEP = + Builder.CreateGEP(SrcEltType, SrcOp, Indices, "", GEP.isInBounds()); + NGEP->takeName(&GEP); + + // Preserve GEP address space to satisfy users + if (NGEP->getType()->getPointerAddressSpace() != GEP.getAddressSpace()) + return new AddrSpaceCastInst(NGEP, GEP.getType()); + + return replaceInstUsesWith(GEP, NGEP); + } + + // See if we can simplify: + // X = bitcast A* to B* + // Y = gep X, <...constant indices...> + // into a gep of the original struct. This is important for SROA and alias + // analysis of unions. 
If "A" is also a bitcast, wait for A/X to be merged. + unsigned OffsetBits = DL.getIndexTypeSizeInBits(GEP.getType()); + APInt Offset(OffsetBits, 0); + + // If the bitcast argument is an allocation, The bitcast is for convertion + // to actual type of allocation. Removing such bitcasts, results in having + // GEPs with i8* base and pure byte offsets. That means GEP is not aware of + // struct or array hierarchy. + // By avoiding such GEPs, phi translation and MemoryDependencyAnalysis have + // a better chance to succeed. + if (!isa<BitCastInst>(SrcOp) && GEP.accumulateConstantOffset(DL, Offset) && + !isAllocationFn(SrcOp, &TLI)) { + // If this GEP instruction doesn't move the pointer, just replace the GEP + // with a bitcast of the real input to the dest type. + if (!Offset) { + // If the bitcast is of an allocation, and the allocation will be + // converted to match the type of the cast, don't touch this. + if (isa<AllocaInst>(SrcOp)) { + // See if the bitcast simplifies, if so, don't nuke this GEP yet. + if (Instruction *I = visitBitCast(*BCI)) { + if (I != BCI) { + I->takeName(BCI); + BCI->getParent()->getInstList().insert(BCI->getIterator(), I); + replaceInstUsesWith(*BCI, I); + } + return &GEP; + } + } + + if (SrcType->getPointerAddressSpace() != GEP.getAddressSpace()) + return new AddrSpaceCastInst(SrcOp, GEP.getType()); + return new BitCastInst(SrcOp, GEP.getType()); + } + + // Otherwise, if the offset is non-zero, we need to find out if there is a + // field at Offset in 'A's type. If so, we can pull the cast through the + // GEP. 
+ SmallVector<Value *, 8> NewIndices; + if (findElementAtOffset(SrcType, Offset.getSExtValue(), NewIndices, DL)) { + Value *NGEP = Builder.CreateGEP(SrcEltType, SrcOp, NewIndices, "", + GEP.isInBounds()); + + if (NGEP->getType() == GEP.getType()) + return replaceInstUsesWith(GEP, NGEP); + NGEP->takeName(&GEP); + + if (NGEP->getType()->getPointerAddressSpace() != GEP.getAddressSpace()) + return new AddrSpaceCastInst(NGEP, GEP.getType()); + return new BitCastInst(NGEP, GEP.getType()); + } + } + + return nullptr; +} + +Instruction *InstCombinerImpl::visitGetElementPtrInst(GetElementPtrInst &GEP) { + Value *PtrOp = GEP.getOperand(0); + SmallVector<Value *, 8> Indices(GEP.indices()); + Type *GEPType = GEP.getType(); + Type *GEPEltType = GEP.getSourceElementType(); + bool IsGEPSrcEleScalable = isa<ScalableVectorType>(GEPEltType); + if (Value *V = simplifyGEPInst(GEPEltType, PtrOp, Indices, GEP.isInBounds(), + SQ.getWithInstruction(&GEP))) + return replaceInstUsesWith(GEP, V); + + // For vector geps, use the generic demanded vector support. + // Skip if GEP return type is scalable. The number of elements is unknown at + // compile-time. + if (auto *GEPFVTy = dyn_cast<FixedVectorType>(GEPType)) { + auto VWidth = GEPFVTy->getNumElements(); + APInt UndefElts(VWidth, 0); + APInt AllOnesEltMask(APInt::getAllOnes(VWidth)); + if (Value *V = SimplifyDemandedVectorElts(&GEP, AllOnesEltMask, + UndefElts)) { + if (V != &GEP) + return replaceInstUsesWith(GEP, V); + return &GEP; + } + + // TODO: 1) Scalarize splat operands, 2) scalarize entire instruction if + // possible (decide on canonical form for pointer broadcast), 3) exploit + // undef elements to decrease demanded bits + } + + // Eliminate unneeded casts for indices, and replace indices which displace + // by multiples of a zero size type with zero. + bool MadeChange = false; + + // Index width may not be the same width as pointer width. + // Data layout chooses the right type based on supported integer types. 
+ Type *NewScalarIndexTy = + DL.getIndexType(GEP.getPointerOperandType()->getScalarType()); + + gep_type_iterator GTI = gep_type_begin(GEP); + for (User::op_iterator I = GEP.op_begin() + 1, E = GEP.op_end(); I != E; + ++I, ++GTI) { + // Skip indices into struct types. + if (GTI.isStruct()) + continue; + + Type *IndexTy = (*I)->getType(); + Type *NewIndexType = + IndexTy->isVectorTy() + ? VectorType::get(NewScalarIndexTy, + cast<VectorType>(IndexTy)->getElementCount()) + : NewScalarIndexTy; + + // If the element type has zero size then any index over it is equivalent + // to an index of zero, so replace it with zero if it is not zero already. + Type *EltTy = GTI.getIndexedType(); + if (EltTy->isSized() && DL.getTypeAllocSize(EltTy).isZero()) + if (!isa<Constant>(*I) || !match(I->get(), m_Zero())) { + *I = Constant::getNullValue(NewIndexType); + MadeChange = true; + } + + if (IndexTy != NewIndexType) { + // If we are using a wider index than needed for this platform, shrink + // it to what we need. If narrower, sign-extend it to what we need. + // This explicit cast can make subsequent optimizations more obvious. + *I = Builder.CreateIntCast(*I, NewIndexType, true); + MadeChange = true; + } + } + if (MadeChange) + return &GEP; + + // Check to see if the inputs to the PHI node are getelementptr instructions. + if (auto *PN = dyn_cast<PHINode>(PtrOp)) { + auto *Op1 = dyn_cast<GetElementPtrInst>(PN->getOperand(0)); + if (!Op1) + return nullptr; + + // Don't fold a GEP into itself through a PHI node. This can only happen + // through the back-edge of a loop. Folding a GEP into itself means that + // the value of the previous iteration needs to be stored in the meantime, + // thus requiring an additional register variable to be live, but not + // actually achieving anything (the GEP still needs to be executed once per + // loop iteration). 
+ if (Op1 == &GEP) + return nullptr; + + int DI = -1; + + for (auto I = PN->op_begin()+1, E = PN->op_end(); I !=E; ++I) { + auto *Op2 = dyn_cast<GetElementPtrInst>(*I); + if (!Op2 || Op1->getNumOperands() != Op2->getNumOperands() || + Op1->getSourceElementType() != Op2->getSourceElementType()) + return nullptr; + + // As for Op1 above, don't try to fold a GEP into itself. + if (Op2 == &GEP) + return nullptr; + + // Keep track of the type as we walk the GEP. + Type *CurTy = nullptr; + + for (unsigned J = 0, F = Op1->getNumOperands(); J != F; ++J) { + if (Op1->getOperand(J)->getType() != Op2->getOperand(J)->getType()) + return nullptr; + + if (Op1->getOperand(J) != Op2->getOperand(J)) { + if (DI == -1) { + // We have not seen any differences yet in the GEPs feeding the + // PHI yet, so we record this one if it is allowed to be a + // variable. + + // The first two arguments can vary for any GEP, the rest have to be + // static for struct slots + if (J > 1) { + assert(CurTy && "No current type?"); + if (CurTy->isStructTy()) + return nullptr; + } + + DI = J; + } else { + // The GEP is different by more than one input. While this could be + // extended to support GEPs that vary by more than one variable it + // doesn't make sense since it greatly increases the complexity and + // would result in an R+R+R addressing mode which no backend + // directly supports and would need to be broken into several + // simpler instructions anyway. + return nullptr; + } + } + + // Sink down a layer of the type for the next iteration. + if (J > 0) { + if (J == 1) { + CurTy = Op1->getSourceElementType(); + } else { + CurTy = + GetElementPtrInst::getTypeAtIndex(CurTy, Op1->getOperand(J)); + } + } + } + } + + // If not all GEPs are identical we'll have to create a new PHI node. + // Check that the old PHI node has only one use so that it will get + // removed. 
+ if (DI != -1 && !PN->hasOneUse()) + return nullptr; + + auto *NewGEP = cast<GetElementPtrInst>(Op1->clone()); + if (DI == -1) { + // All the GEPs feeding the PHI are identical. Clone one down into our + // BB so that it can be merged with the current GEP. + } else { + // All the GEPs feeding the PHI differ at a single offset. Clone a GEP + // into the current block so it can be merged, and create a new PHI to + // set that index. + PHINode *NewPN; + { + IRBuilderBase::InsertPointGuard Guard(Builder); + Builder.SetInsertPoint(PN); + NewPN = Builder.CreatePHI(Op1->getOperand(DI)->getType(), + PN->getNumOperands()); + } + + for (auto &I : PN->operands()) + NewPN->addIncoming(cast<GEPOperator>(I)->getOperand(DI), + PN->getIncomingBlock(I)); + + NewGEP->setOperand(DI, NewPN); + } + + GEP.getParent()->getInstList().insert( + GEP.getParent()->getFirstInsertionPt(), NewGEP); + replaceOperand(GEP, 0, NewGEP); + PtrOp = NewGEP; + } + + if (auto *Src = dyn_cast<GEPOperator>(PtrOp)) + if (Instruction *I = visitGEPOfGEP(GEP, Src)) + return I; + + // Skip if GEP source element type is scalable. The type alloc size is unknown + // at compile-time. 
+ if (GEP.getNumIndices() == 1 && !IsGEPSrcEleScalable) { + unsigned AS = GEP.getPointerAddressSpace(); + if (GEP.getOperand(1)->getType()->getScalarSizeInBits() == + DL.getIndexSizeInBits(AS)) { + uint64_t TyAllocSize = DL.getTypeAllocSize(GEPEltType).getFixedSize(); + + bool Matched = false; + uint64_t C; + Value *V = nullptr; + if (TyAllocSize == 1) { + V = GEP.getOperand(1); + Matched = true; + } else if (match(GEP.getOperand(1), + m_AShr(m_Value(V), m_ConstantInt(C)))) { + if (TyAllocSize == 1ULL << C) + Matched = true; + } else if (match(GEP.getOperand(1), + m_SDiv(m_Value(V), m_ConstantInt(C)))) { + if (TyAllocSize == C) + Matched = true; + } + + // Canonicalize (gep i8* X, (ptrtoint Y)-(ptrtoint X)) to (bitcast Y), but + // only if both point to the same underlying object (otherwise provenance + // is not necessarily retained). + Value *Y; + Value *X = GEP.getOperand(0); + if (Matched && + match(V, m_Sub(m_PtrToInt(m_Value(Y)), m_PtrToInt(m_Specific(X)))) && + getUnderlyingObject(X) == getUnderlyingObject(Y)) + return CastInst::CreatePointerBitCastOrAddrSpaceCast(Y, GEPType); + } + } + + // We do not handle pointer-vector geps here. + if (GEPType->isVectorTy()) + return nullptr; + + // Handle gep(bitcast x) and gep(gep x, 0, 0, 0). + Value *StrippedPtr = PtrOp->stripPointerCasts(); + PointerType *StrippedPtrTy = cast<PointerType>(StrippedPtr->getType()); + + // TODO: The basic approach of these folds is not compatible with opaque + // pointers, because we can't use bitcasts as a hint for a desirable GEP + // type. Instead, we should perform canonicalization directly on the GEP + // type. For now, skip these. + if (StrippedPtr != PtrOp && !StrippedPtrTy->isOpaque()) { + bool HasZeroPointerIndex = false; + Type *StrippedPtrEltTy = StrippedPtrTy->getNonOpaquePointerElementType(); + + if (auto *C = dyn_cast<ConstantInt>(GEP.getOperand(1))) + HasZeroPointerIndex = C->isZero(); + + // Transform: GEP (bitcast [10 x i8]* X to [0 x i8]*), i32 0, ... 
+ // into : GEP [10 x i8]* X, i32 0, ... + // + // Likewise, transform: GEP (bitcast i8* X to [0 x i8]*), i32 0, ... + // into : GEP i8* X, ... + // + // This occurs when the program declares an array extern like "int X[];" + if (HasZeroPointerIndex) { + if (auto *CATy = dyn_cast<ArrayType>(GEPEltType)) { + // GEP (bitcast i8* X to [0 x i8]*), i32 0, ... ? + if (CATy->getElementType() == StrippedPtrEltTy) { + // -> GEP i8* X, ... + SmallVector<Value *, 8> Idx(drop_begin(GEP.indices())); + GetElementPtrInst *Res = GetElementPtrInst::Create( + StrippedPtrEltTy, StrippedPtr, Idx, GEP.getName()); + Res->setIsInBounds(GEP.isInBounds()); + if (StrippedPtrTy->getAddressSpace() == GEP.getAddressSpace()) + return Res; + // Insert Res, and create an addrspacecast. + // e.g., + // GEP (addrspacecast i8 addrspace(1)* X to [0 x i8]*), i32 0, ... + // -> + // %0 = GEP i8 addrspace(1)* X, ... + // addrspacecast i8 addrspace(1)* %0 to i8* + return new AddrSpaceCastInst(Builder.Insert(Res), GEPType); + } + + if (auto *XATy = dyn_cast<ArrayType>(StrippedPtrEltTy)) { + // GEP (bitcast [10 x i8]* X to [0 x i8]*), i32 0, ... ? + if (CATy->getElementType() == XATy->getElementType()) { + // -> GEP [10 x i8]* X, i32 0, ... + // At this point, we know that the cast source type is a pointer + // to an array of the same type as the destination pointer + // array. Because the array type is never stepped over (there + // is a leading zero) we can fold the cast into this GEP. + if (StrippedPtrTy->getAddressSpace() == GEP.getAddressSpace()) { + GEP.setSourceElementType(XATy); + return replaceOperand(GEP, 0, StrippedPtr); + } + // Cannot replace the base pointer directly because StrippedPtr's + // address space is different. Instead, create a new GEP followed by + // an addrspacecast. + // e.g., + // GEP (addrspacecast [10 x i8] addrspace(1)* X to [0 x i8]*), + // i32 0, ... + // -> + // %0 = GEP [10 x i8] addrspace(1)* X, ... 
+ // addrspacecast i8 addrspace(1)* %0 to i8* + SmallVector<Value *, 8> Idx(GEP.indices()); + Value *NewGEP = + Builder.CreateGEP(StrippedPtrEltTy, StrippedPtr, Idx, + GEP.getName(), GEP.isInBounds()); + return new AddrSpaceCastInst(NewGEP, GEPType); + } + } + } + } else if (GEP.getNumOperands() == 2 && !IsGEPSrcEleScalable) { + // Skip if GEP source element type is scalable. The type alloc size is + // unknown at compile-time. + // Transform things like: %t = getelementptr i32* + // bitcast ([2 x i32]* %str to i32*), i32 %V into: %t1 = getelementptr [2 + // x i32]* %str, i32 0, i32 %V; bitcast + if (StrippedPtrEltTy->isArrayTy() && + DL.getTypeAllocSize(StrippedPtrEltTy->getArrayElementType()) == + DL.getTypeAllocSize(GEPEltType)) { + Type *IdxType = DL.getIndexType(GEPType); + Value *Idx[2] = {Constant::getNullValue(IdxType), GEP.getOperand(1)}; + Value *NewGEP = Builder.CreateGEP(StrippedPtrEltTy, StrippedPtr, Idx, + GEP.getName(), GEP.isInBounds()); + + // V and GEP are both pointer types --> BitCast + return CastInst::CreatePointerBitCastOrAddrSpaceCast(NewGEP, GEPType); + } + + // Transform things like: + // %V = mul i64 %N, 4 + // %t = getelementptr i8* bitcast (i32* %arr to i8*), i32 %V + // into: %t1 = getelementptr i32* %arr, i32 %N; bitcast + if (GEPEltType->isSized() && StrippedPtrEltTy->isSized()) { + // Check that changing the type amounts to dividing the index by a scale + // factor. + uint64_t ResSize = DL.getTypeAllocSize(GEPEltType).getFixedSize(); + uint64_t SrcSize = DL.getTypeAllocSize(StrippedPtrEltTy).getFixedSize(); + if (ResSize && SrcSize % ResSize == 0) { + Value *Idx = GEP.getOperand(1); + unsigned BitWidth = Idx->getType()->getPrimitiveSizeInBits(); + uint64_t Scale = SrcSize / ResSize; + + // Earlier transforms ensure that the index has the right type + // according to Data Layout, which considerably simplifies the + // logic by eliminating implicit casts. 
+ assert(Idx->getType() == DL.getIndexType(GEPType) && + "Index type does not match the Data Layout preferences"); + + bool NSW; + if (Value *NewIdx = Descale(Idx, APInt(BitWidth, Scale), NSW)) { + // Successfully decomposed Idx as NewIdx * Scale, form a new GEP. + // If the multiplication NewIdx * Scale may overflow then the new + // GEP may not be "inbounds". + Value *NewGEP = + Builder.CreateGEP(StrippedPtrEltTy, StrippedPtr, NewIdx, + GEP.getName(), GEP.isInBounds() && NSW); + + // The NewGEP must be pointer typed, so must the old one -> BitCast + return CastInst::CreatePointerBitCastOrAddrSpaceCast(NewGEP, + GEPType); + } + } + } + + // Similarly, transform things like: + // getelementptr i8* bitcast ([100 x double]* X to i8*), i32 %tmp + // (where tmp = 8*tmp2) into: + // getelementptr [100 x double]* %arr, i32 0, i32 %tmp2; bitcast + if (GEPEltType->isSized() && StrippedPtrEltTy->isSized() && + StrippedPtrEltTy->isArrayTy()) { + // Check that changing to the array element type amounts to dividing the + // index by a scale factor. + uint64_t ResSize = DL.getTypeAllocSize(GEPEltType).getFixedSize(); + uint64_t ArrayEltSize = + DL.getTypeAllocSize(StrippedPtrEltTy->getArrayElementType()) + .getFixedSize(); + if (ResSize && ArrayEltSize % ResSize == 0) { + Value *Idx = GEP.getOperand(1); + unsigned BitWidth = Idx->getType()->getPrimitiveSizeInBits(); + uint64_t Scale = ArrayEltSize / ResSize; + + // Earlier transforms ensure that the index has the right type + // according to the Data Layout, which considerably simplifies + // the logic by eliminating implicit casts. + assert(Idx->getType() == DL.getIndexType(GEPType) && + "Index type does not match the Data Layout preferences"); + + bool NSW; + if (Value *NewIdx = Descale(Idx, APInt(BitWidth, Scale), NSW)) { + // Successfully decomposed Idx as NewIdx * Scale, form a new GEP. + // If the multiplication NewIdx * Scale may overflow then the new + // GEP may not be "inbounds". 
+ Type *IndTy = DL.getIndexType(GEPType); + Value *Off[2] = {Constant::getNullValue(IndTy), NewIdx}; + + Value *NewGEP = + Builder.CreateGEP(StrippedPtrEltTy, StrippedPtr, Off, + GEP.getName(), GEP.isInBounds() && NSW); + // The NewGEP must be pointer typed, so must the old one -> BitCast + return CastInst::CreatePointerBitCastOrAddrSpaceCast(NewGEP, + GEPType); + } + } + } + } + } + + // addrspacecast between types is canonicalized as a bitcast, then an + // addrspacecast. To take advantage of the below bitcast + struct GEP, look + // through the addrspacecast. + Value *ASCStrippedPtrOp = PtrOp; + if (auto *ASC = dyn_cast<AddrSpaceCastInst>(PtrOp)) { + // X = bitcast A addrspace(1)* to B addrspace(1)* + // Y = addrspacecast A addrspace(1)* to B addrspace(2)* + // Z = gep Y, <...constant indices...> + // Into an addrspacecasted GEP of the struct. + if (auto *BC = dyn_cast<BitCastInst>(ASC->getOperand(0))) + ASCStrippedPtrOp = BC; + } + + if (auto *BCI = dyn_cast<BitCastInst>(ASCStrippedPtrOp)) + if (Instruction *I = visitGEPOfBitcast(BCI, GEP)) + return I; + + if (!GEP.isInBounds()) { + unsigned IdxWidth = + DL.getIndexSizeInBits(PtrOp->getType()->getPointerAddressSpace()); + APInt BasePtrOffset(IdxWidth, 0); + Value *UnderlyingPtrOp = + PtrOp->stripAndAccumulateInBoundsConstantOffsets(DL, + BasePtrOffset); + if (auto *AI = dyn_cast<AllocaInst>(UnderlyingPtrOp)) { + if (GEP.accumulateConstantOffset(DL, BasePtrOffset) && + BasePtrOffset.isNonNegative()) { + APInt AllocSize( + IdxWidth, + DL.getTypeAllocSize(AI->getAllocatedType()).getKnownMinSize()); + if (BasePtrOffset.ule(AllocSize)) { + return GetElementPtrInst::CreateInBounds( + GEP.getSourceElementType(), PtrOp, Indices, GEP.getName()); + } + } + } + } + + if (Instruction *R = foldSelectGEP(GEP, Builder)) + return R; + + return nullptr; +} + +static bool isNeverEqualToUnescapedAlloc(Value *V, const TargetLibraryInfo &TLI, + Instruction *AI) { + if (isa<ConstantPointerNull>(V)) + return true; + if (auto *LI = 
dyn_cast<LoadInst>(V)) + return isa<GlobalVariable>(LI->getPointerOperand()); + // Two distinct allocations will never be equal. + return isAllocLikeFn(V, &TLI) && V != AI; +} + +/// Given a call CB which uses an address UsedV, return true if we can prove the +/// call's only possible effect is storing to V. +static bool isRemovableWrite(CallBase &CB, Value *UsedV, + const TargetLibraryInfo &TLI) { + if (!CB.use_empty()) + // TODO: add recursion if returned attribute is present + return false; + + if (CB.isTerminator()) + // TODO: remove implementation restriction + return false; + + if (!CB.willReturn() || !CB.doesNotThrow()) + return false; + + // If the only possible side effect of the call is writing to the alloca, + // and the result isn't used, we can safely remove any reads implied by the + // call including those which might read the alloca itself. + Optional<MemoryLocation> Dest = MemoryLocation::getForDest(&CB, TLI); + return Dest && Dest->Ptr == UsedV; +} + +static bool isAllocSiteRemovable(Instruction *AI, + SmallVectorImpl<WeakTrackingVH> &Users, + const TargetLibraryInfo &TLI) { + SmallVector<Instruction*, 4> Worklist; + const Optional<StringRef> Family = getAllocationFamily(AI, &TLI); + Worklist.push_back(AI); + + do { + Instruction *PI = Worklist.pop_back_val(); + for (User *U : PI->users()) { + Instruction *I = cast<Instruction>(U); + switch (I->getOpcode()) { + default: + // Give up the moment we see something we can't handle. + return false; + + case Instruction::AddrSpaceCast: + case Instruction::BitCast: + case Instruction::GetElementPtr: + Users.emplace_back(I); + Worklist.push_back(I); + continue; + + case Instruction::ICmp: { + ICmpInst *ICI = cast<ICmpInst>(I); + // We can fold eq/ne comparisons with null to false/true, respectively. + // We also fold comparisons in some conditions provided the alloc has + // not escaped (see isNeverEqualToUnescapedAlloc). 
// NOTE: this span opens inside a function whose beginning precedes the visible
// text (presumably isAllocSiteRemovable(), which visitAllocSite() below calls
// with the same Users/TLI arguments -- confirm). The fragment is kept as is.
        if (!ICI->isEquality())
          return false;
        unsigned OtherIndex = (ICI->getOperand(0) == PI) ? 1 : 0;
        if (!isNeverEqualToUnescapedAlloc(ICI->getOperand(OtherIndex), TLI, AI))
          return false;
        Users.emplace_back(I);
        continue;
      }

      case Instruction::Call:
        // Ignore no-op and store intrinsics.
        if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(I)) {
          switch (II->getIntrinsicID()) {
          default:
            return false;

          case Intrinsic::memmove:
          case Intrinsic::memcpy:
          case Intrinsic::memset: {
            MemIntrinsic *MI = cast<MemIntrinsic>(II);
            if (MI->isVolatile() || MI->getRawDest() != PI)
              return false;
            LLVM_FALLTHROUGH;
          }
          case Intrinsic::assume:
          case Intrinsic::invariant_start:
          case Intrinsic::invariant_end:
          case Intrinsic::lifetime_start:
          case Intrinsic::lifetime_end:
          case Intrinsic::objectsize:
            Users.emplace_back(I);
            continue;
          case Intrinsic::launder_invariant_group:
          case Intrinsic::strip_invariant_group:
            Users.emplace_back(I);
            Worklist.push_back(I);
            continue;
          }
        }

        if (isRemovableWrite(*cast<CallBase>(I), PI, TLI)) {
          Users.emplace_back(I);
          continue;
        }

        if (isFreeCall(I, &TLI) && getAllocationFamily(I, &TLI) == Family) {
          assert(Family);
          Users.emplace_back(I);
          continue;
        }

        if (isReallocLikeFn(I, &TLI) &&
            getAllocationFamily(I, &TLI) == Family) {
          assert(Family);
          Users.emplace_back(I);
          Worklist.push_back(I);
          continue;
        }

        return false;

      case Instruction::Store: {
        StoreInst *SI = cast<StoreInst>(I);
        if (SI->isVolatile() || SI->getPointerOperand() != PI)
          return false;
        Users.emplace_back(I);
        continue;
      }
      }
      llvm_unreachable("missing a return?");
    }
  } while (!Worklist.empty());
  return true;
}

/// Try to delete the allocation site \p MI together with every instruction
/// that uses it.
///
/// \p MI must be an alloca, or a call for which isAllocRemovable() holds (this
/// is asserted). If isAllocSiteRemovable() accepts the site, all collected
/// users are erased and the result of erasing \p MI is returned; otherwise
/// nothing is changed and nullptr is returned.
Instruction *InstCombinerImpl::visitAllocSite(Instruction &MI) {
  assert(isa<AllocaInst>(MI) || isAllocRemovable(&cast<CallBase>(MI), &TLI));

  // If we have a malloc call which is only used in any amount of comparisons to
  // null and free calls, delete the calls and replace the comparisons with true
  // or false as appropriate.

  // This is based on the principle that we can substitute our own allocation
  // function (which will never return null) rather than knowledge of the
  // specific function being called. In some sense this can change the permitted
  // outputs of a program (when we convert a malloc to an alloca, the fact that
  // the allocation is now on the stack is potentially visible, for example),
  // but we believe in a permissible manner.
  // Weak handles: entries are tested for null before use in both loops below.
  SmallVector<WeakTrackingVH, 64> Users;

  // If we are removing an alloca with a dbg.declare, insert dbg.value calls
  // before each store.
  SmallVector<DbgVariableIntrinsic *, 8> DVIs;
  std::unique_ptr<DIBuilder> DIB;
  // Debug users and the DIBuilder are only set up for allocas; for calls DVIs
  // stays empty, so DIB is never dereferenced below.
  if (isa<AllocaInst>(MI)) {
    findDbgUsers(DVIs, &MI);
    DIB.reset(new DIBuilder(*MI.getModule(), /*AllowUnresolved=*/false));
  }

  if (isAllocSiteRemovable(&MI, Users, TLI)) {
    // Pass 1: lower and erase objectsize calls only.
    for (unsigned i = 0, e = Users.size(); i != e; ++i) {
      // Lowering all @llvm.objectsize calls first because they may
      // use a bitcast/GEP of the alloca we are removing.
      if (!Users[i])
        continue;

      Instruction *I = cast<Instruction>(&*Users[i]);

      if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(I)) {
        if (II->getIntrinsicID() == Intrinsic::objectsize) {
          Value *Result =
              lowerObjectSizeCall(II, DL, &TLI, AA, /*MustSucceed=*/true);
          replaceInstUsesWith(*I, Result);
          eraseInstFromFunction(*I);
          Users[i] = nullptr; // Skip examining in the next loop.
        }
      }
    }
    // Pass 2: erase every remaining user.
    for (unsigned i = 0, e = Users.size(); i != e; ++i) {
      if (!Users[i])
        continue;

      Instruction *I = cast<Instruction>(&*Users[i]);

      if (ICmpInst *C = dyn_cast<ICmpInst>(I)) {
        // Fold the comparison to the i1 constant given by isFalseWhenEqual().
        replaceInstUsesWith(*C,
                            ConstantInt::get(Type::getInt1Ty(C->getContext()),
                                             C->isFalseWhenEqual()));
      } else if (auto *SI = dyn_cast<StoreInst>(I)) {
        // Turn each address-of-variable debug user into a debug value for
        // this store before the store is erased.
        for (auto *DVI : DVIs)
          if (DVI->isAddressOfVariable())
            ConvertDebugDeclareToDebugValue(DVI, SI, *DIB);
      } else {
        // Casts, GEP, or anything else: we're about to delete this instruction,
        // so it can not have any valid uses.
        replaceInstUsesWith(*I, PoisonValue::get(I->getType()));
      }
      eraseInstFromFunction(*I);
    }

    if (InvokeInst *II = dyn_cast<InvokeInst>(&MI)) {
      // Replace invoke with a NOP intrinsic to maintain the original CFG
      Module *M = II->getModule();
      Function *F = Intrinsic::getDeclaration(M, Intrinsic::donothing);
      InvokeInst::Create(F, II->getNormalDest(), II->getUnwindDest(),
                         None, "", II->getParent());
    }

    // Remove debug intrinsics which describe the value contained within the
    // alloca. In addition to removing dbg.{declare,addr} which simply point to
    // the alloca, remove dbg.value(<alloca>, ..., DW_OP_deref)'s as well, e.g.:
    //
    // ```
    //   define void @foo(i32 %0) {
    //     %a = alloca i32                              ; Deleted.
    //     store i32 %0, i32* %a
    //     dbg.value(i32 %0, "arg0")                    ; Not deleted.
    //     dbg.value(i32* %a, "arg0", DW_OP_deref)      ; Deleted.
    //     call void @trivially_inlinable_no_op(i32* %a)
    //     ret void
    //  }
    // ```
    //
    // This may not be required if we stop describing the contents of allocas
    // using dbg.value(<alloca>, ..., DW_OP_deref), but we currently do this in
    // the LowerDbgDeclare utility.
    //
    // If there is a dead store to `%a` in @trivially_inlinable_no_op, the
    // "arg0" dbg.value may be stale after the call. However, failing to remove
    // the DW_OP_deref dbg.value causes large gaps in location coverage.
    for (auto *DVI : DVIs)
      if (DVI->isAddressOfVariable() || DVI->getExpression()->startsWithDeref())
        DVI->eraseFromParent();

    return eraseInstFromFunction(MI);
  }
  return nullptr;
}

/// Move the call to free before a NULL test.
///
/// Check if this free is accessed after its argument has been tested
/// against NULL (property 0).
/// If yes, it is legal to move this call in its predecessor block.
///
/// The move is performed only if the block containing the call to free
/// will be removed, i.e.:
/// 1. it has only one predecessor P, and P has two successors
/// 2. it contains the call, noops, and an unconditional branch
/// 3. its successor is the same as its predecessor's successor
///
/// The profitability is out-of concern here and this function should
/// be called only if the caller knows this transformation would be
/// profitable (e.g., for code size).
///
/// \returns \p FI if the call was moved, nullptr if any constraint failed.
static Instruction *tryToMoveFreeBeforeNullTest(CallInst &FI,
                                                const DataLayout &DL) {
  Value *Op = FI.getArgOperand(0);
  BasicBlock *FreeInstrBB = FI.getParent();
  BasicBlock *PredBB = FreeInstrBB->getSinglePredecessor();

  // Validate part of constraint #1: Only one predecessor
  // FIXME: We can extend the number of predecessor, but in that case, we
  //        would duplicate the call to free in each predecessor and it may
  //        not be profitable even for code size.
  if (!PredBB)
    return nullptr;

  // Validate constraint #2: Does this block contain only the call to
  //                         free, noops, and an unconditional branch?
  BasicBlock *SuccBB;
  Instruction *FreeInstrBBTerminator = FreeInstrBB->getTerminator();
  if (!match(FreeInstrBBTerminator, m_UnconditionalBr(SuccBB)))
    return nullptr;

  // If there are only 2 instructions in the block, at this point,
  // these are the call to free and the unconditional branch.
  // If there are more than 2 instructions, check that they are noops
  // i.e., they won't hurt the performance of the generated code.
// NOTE: this span resumes the body of tryToMoveFreeBeforeNullTest(); its
// signature and the first two constraint checks are in the preceding text.
  if (FreeInstrBB->size() != 2) {
    for (const Instruction &Inst : FreeInstrBB->instructionsWithoutDebug()) {
      if (&Inst == &FI || &Inst == FreeInstrBBTerminator)
        continue;
      // Anything besides the call and the terminator must be a no-op cast.
      auto *Cast = dyn_cast<CastInst>(&Inst);
      if (!Cast || !Cast->isNoopCast(DL))
        return nullptr;
    }
  }
  // Validate the rest of constraint #1 by matching on the pred branch.
  Instruction *TI = PredBB->getTerminator();
  BasicBlock *TrueBB, *FalseBB;
  ICmpInst::Predicate Pred;
  // The branch condition must compare the freed pointer (with or without its
  // pointer casts stripped) against zero.
  if (!match(TI, m_Br(m_ICmp(Pred,
                             m_CombineOr(m_Specific(Op),
                                         m_Specific(Op->stripPointerCasts())),
                             m_Zero()),
                      TrueBB, FalseBB)))
    return nullptr;
  if (Pred != ICmpInst::ICMP_EQ && Pred != ICmpInst::ICMP_NE)
    return nullptr;

  // Validate constraint #3: Ensure the null case just falls through.
  if (SuccBB != (Pred == ICmpInst::ICMP_EQ ? TrueBB : FalseBB))
    return nullptr;
  assert(FreeInstrBB == (Pred == ICmpInst::ICMP_EQ ? FalseBB : TrueBB) &&
         "Broken CFG: missing edge from predecessor to successor");

  // At this point, we know that everything in FreeInstrBB can be moved
  // before TI.
  // The early-inc range keeps iteration valid while instructions are moved
  // out of the block being walked.
  for (Instruction &Instr : llvm::make_early_inc_range(*FreeInstrBB)) {
    if (&Instr == FreeInstrBBTerminator)
      break;
    Instr.moveBefore(TI);
  }
  assert(FreeInstrBB->size() == 1 &&
         "Only the branch instruction should remain");

  // Now that we've moved the call to free before the NULL check, we have to
  // remove any attributes on its parameter that imply it's non-null, because
  // those attributes might have only been valid because of the NULL check, and
  // we can get miscompiles if we keep them. This is conservative if non-null is
  // also implied by something other than the NULL check, but it's guaranteed to
  // be correct, and the conservativeness won't matter in practice, since the
  // attributes are irrelevant for the call to free itself and the pointer
  // shouldn't be used after the call.
  AttributeList Attrs = FI.getAttributes();
  Attrs = Attrs.removeParamAttribute(FI.getContext(), 0, Attribute::NonNull);
  Attribute Dereferenceable = Attrs.getParamAttr(0, Attribute::Dereferenceable);
  if (Dereferenceable.isValid()) {
    // Swap dereferenceable(N) for dereferenceable_or_null(N), keeping the
    // byte count.
    uint64_t Bytes = Dereferenceable.getDereferenceableBytes();
    Attrs = Attrs.removeParamAttribute(FI.getContext(), 0,
                                       Attribute::Dereferenceable);
    Attrs = Attrs.addDereferenceableOrNullParamAttr(FI.getContext(), 0, Bytes);
  }
  FI.setAttributes(Attrs);

  return &FI;
}

/// Simplify the free call \p FI based on its first argument.
///
/// Handles, in order: free(undef), free(null), free of a single-use
/// realloc-like call, and (only when MinimizeSize is set and the callee is the
/// library function `free`) hoisting the call above a null test.
///
/// \returns the instruction produced by the applied transform, or nullptr if
/// none applied.
Instruction *InstCombinerImpl::visitFree(CallInst &FI) {
  Value *Op = FI.getArgOperand(0);

  // free undef -> unreachable.
  if (isa<UndefValue>(Op)) {
    // Leave a marker since we can't modify the CFG here.
    CreateNonTerminatorUnreachable(&FI);
    return eraseInstFromFunction(FI);
  }

  // If we have 'free null' delete the instruction.  This can happen in stl code
  // when lots of inlining happens.
  if (isa<ConstantPointerNull>(Op))
    return eraseInstFromFunction(FI);

  // If we had free(realloc(...)) with no intervening uses, then eliminate the
  // realloc() entirely.
  if (CallInst *CI = dyn_cast<CallInst>(Op)) {
    if (CI->hasOneUse() && isReallocLikeFn(CI, &TLI)) {
      // Uses of the realloc result are redirected to its operand 0 before the
      // realloc call is erased.
      return eraseInstFromFunction(
          *replaceInstUsesWith(*CI, CI->getOperand(0)));
    }
  }

  // If we optimize for code size, try to move the call to free before the null
  // test so that simplify cfg can remove the empty block and dead code
  // elimination the branch. I.e., helps to turn something like:
  // if (foo) free(foo);
  // into
  // free(foo);
  //
  // Note that we can only do this for 'free' and not for any flavor of
  // 'operator delete'; there is no 'operator delete' symbol for which we are
  // permitted to invent a call, even if we're passing in a null pointer.
  if (MinimizeSize) {
    LibFunc Func;
    if (TLI.getLibFunc(FI, Func) && TLI.has(Func) && Func == LibFunc_free)
      if (Instruction *I = tryToMoveFreeBeforeNullTest(FI, DL))
        return I;
  }

  return nullptr;
}

/// Return true if \p V is a CallInst for which isMustTailCall() holds; false
/// for every other value.
static bool isMustTailCall(Value *V) {
  if (auto *CI = dyn_cast<CallInst>(V))
    return CI->isMustTailCall();
  return false;
}

/// Replace a non-constant integer return operand with a constant when
/// computeKnownBits() determines all of its bits.
///
/// \returns the result of replaceOperand() on success; nullptr for `ret void`,
/// non-integer or already-constant operands, musttail call results, or when
/// the value is not fully known.
Instruction *InstCombinerImpl::visitReturnInst(ReturnInst &RI) {
  if (RI.getNumOperands() == 0) // ret void
    return nullptr;

  Value *ResultOp = RI.getOperand(0);
  Type *VTy = ResultOp->getType();
  if (!VTy->isIntegerTy() || isa<Constant>(ResultOp))
    return nullptr;

  // Don't replace result of musttail calls.
  if (isMustTailCall(ResultOp))
    return nullptr;

  // There might be assume intrinsics dominating this return that completely
  // determine the value. If so, constant fold it.
  KnownBits Known = computeKnownBits(ResultOp, 0, &RI);
  if (Known.isConstant())
    return replaceOperand(RI, 0,
        Constant::getIntegerValue(VTy, Known.getConstant()));

  return nullptr;
}

/// Erase the instructions that precede \p I in its block, walking backwards
/// until an EH pad or an instruction that is not guaranteed to transfer
/// execution to its successor is reached. Always returns nullptr.
// WARNING: keep in sync with SimplifyCFGOpt::simplifyUnreachable()!
Instruction *InstCombinerImpl::visitUnreachableInst(UnreachableInst &I) {
  // Try to remove the previous instruction if it must lead to unreachable.
  // This includes instructions like stores and "llvm.assume" that may not get
  // removed by simple dead code elimination.
  while (Instruction *Prev = I.getPrevNonDebugInstruction()) {
    // While we theoretically can erase EH, that would result in a block that
    // used to start with an EH no longer starting with EH, which is invalid.
    // To make it valid, we'd need to fixup predecessors to no longer refer to
    // this block, but that changes CFG, which is not allowed in InstCombine.
    if (Prev->isEHPad())
      return nullptr; // Can not drop any more instructions. We're done here.

    if (!isGuaranteedToTransferExecutionToSuccessor(Prev))
      return nullptr; // Can not drop any more instructions. We're done here.
// NOTE: this span resumes the erase loop of visitUnreachableInst(); the loop
// header and its two early exits are in the preceding text.
    // Otherwise, this instruction can be freely erased,
    // even if it is not side-effect free.

    // A value may still have uses before we process it here (for example, in
    // another unreachable block), so convert those to poison.
    replaceInstUsesWith(*Prev, PoisonValue::get(Prev->getType()));
    eraseInstFromFunction(*Prev);
  }
  // Reached only when the loop ran out of predecessors: apart from debug
  // instructions, the unreachable itself is all that remains in the block.
  assert(I.getParent()->sizeWithoutDebug() == 1 && "The block is now empty.");
  // FIXME: recurse into unconditional predecessors?
  return nullptr;
}

/// Try to sink the last store before the unconditional branch \p BI into the
/// successor block.
///
/// \returns \p BI if mergeStoreIntoSuccessor() succeeded, nullptr otherwise.
Instruction *InstCombinerImpl::visitUnconditionalBranchInst(BranchInst &BI) {
  assert(BI.isUnconditional() && "Only for unconditional branches.");

  // If this store is the second-to-last instruction in the basic block
  // (excluding debug info and bitcasts of pointers) and if the block ends with
  // an unconditional branch, try to move the store to the successor block.

  // Walk backwards from BBI, skipping debug/pseudo instructions and
  // pointer-typed bitcasts; yield the instruction reached if it is a store,
  // otherwise null.
  auto GetLastSinkableStore = [](BasicBlock::iterator BBI) {
    auto IsNoopInstrForStoreMerging = [](BasicBlock::iterator BBI) {
      return BBI->isDebugOrPseudoInst() ||
             (isa<BitCastInst>(BBI) && BBI->getType()->isPointerTy());
    };

    BasicBlock::iterator FirstInstr = BBI->getParent()->begin();
    do {
      // Never step before the first instruction of the block.
      if (BBI != FirstInstr)
        --BBI;
    } while (BBI != FirstInstr && IsNoopInstrForStoreMerging(BBI));

    return dyn_cast<StoreInst>(BBI);
  };

  if (StoreInst *SI = GetLastSinkableStore(BasicBlock::iterator(BI)))
    if (mergeStoreIntoSuccessor(*SI))
      return &BI;

  return nullptr;
}

/// Canonicalize the branch \p BI.
///
/// Unconditional branches are forwarded to visitUnconditionalBranchInst().
/// For conditional branches this removes a `not` on the condition, replaces
/// the condition with false when both successors are the same block, and
/// inverts a single-use fcmp with a non-canonical predicate; the first and
/// last of these swap the successors.
///
/// \returns the modified instruction, or nullptr if nothing changed.
Instruction *InstCombinerImpl::visitBranchInst(BranchInst &BI) {
  if (BI.isUnconditional())
    return visitUnconditionalBranchInst(BI);

  // Change br (not X), label True, label False to: br X, label False, True
  Value *X = nullptr;
  if (match(&BI, m_Br(m_Not(m_Value(X)), m_BasicBlock(), m_BasicBlock())) &&
      !isa<Constant>(X)) {
    // Swap Destinations and condition...
    BI.swapSuccessors();
    return replaceOperand(BI, 0, X);
  }

  // If the condition is irrelevant, remove the use so that other
  // transforms on the condition become more effective.
  if (!isa<ConstantInt>(BI.getCondition()) &&
      BI.getSuccessor(0) == BI.getSuccessor(1))
    return replaceOperand(
        BI, 0, ConstantInt::getFalse(BI.getCondition()->getType()));

  // Canonicalize, for example, fcmp_one -> fcmp_oeq.
  CmpInst::Predicate Pred;
  if (match(&BI, m_Br(m_OneUse(m_FCmp(Pred, m_Value(), m_Value())),
                      m_BasicBlock(), m_BasicBlock())) &&
      !isCanonicalPredicate(Pred)) {
    // Swap destinations and condition.
    // The compare is single-use (m_OneUse), so its predicate is changed in
    // place and the compare is queued on the worklist.
    CmpInst *Cond = cast<CmpInst>(BI.getCondition());
    Cond->setPredicate(CmpInst::getInversePredicate(Pred));
    BI.swapSuccessors();
    Worklist.push(Cond);
    return &BI;
  }

  return nullptr;
}

/// Simplify the condition of the switch \p SI.
///
/// Either folds an `add X, C` condition into the case values, or truncates
/// the condition and all case values when their leading bits are known to be
/// all zeros or all ones.
///
/// \returns the result of replaceOperand() if the switch was rewritten,
/// nullptr otherwise.
Instruction *InstCombinerImpl::visitSwitchInst(SwitchInst &SI) {
  Value *Cond = SI.getCondition();
  Value *Op0;
  ConstantInt *AddRHS;
  if (match(Cond, m_Add(m_Value(Op0), m_ConstantInt(AddRHS)))) {
    // Change 'switch (X+4) case 1:' into 'switch (X) case -3'.
    for (auto Case : SI.cases()) {
      Constant *NewCase = ConstantExpr::getSub(Case.getCaseValue(), AddRHS);
      assert(isa<ConstantInt>(NewCase) &&
             "Result of expression should be constant");
      Case.setValue(cast<ConstantInt>(NewCase));
    }
    return replaceOperand(SI, 0, Op0);
  }

  KnownBits Known = computeKnownBits(Cond, 0, &SI);
  unsigned LeadingKnownZeros = Known.countMinLeadingZeros();
  unsigned LeadingKnownOnes = Known.countMinLeadingOnes();

  // Compute the number of leading bits we can ignore.
  // TODO: A better way to determine this would use ComputeNumSignBits().
// NOTE: this span resumes visitSwitchInst() right after LeadingKnownZeros and
// LeadingKnownOnes have been initialised from computeKnownBits().
  // Clamp both counts to what every case value also exhibits.
  for (auto &C : SI.cases()) {
    LeadingKnownZeros = std::min(
        LeadingKnownZeros, C.getCaseValue()->getValue().countLeadingZeros());
    LeadingKnownOnes = std::min(
        LeadingKnownOnes, C.getCaseValue()->getValue().countLeadingOnes());
  }

  // Width left once the larger of the two shared leading-bit runs is dropped.
  unsigned NewWidth = Known.getBitWidth() - std::max(LeadingKnownZeros, LeadingKnownOnes);

  // Shrink the condition operand if the new type is smaller than the old type.
  // But do not shrink to a non-standard type, because backend can't generate
  // good code for that yet.
  // TODO: We can make it aggressive again after fixing PR39569.
  if (NewWidth > 0 && NewWidth < Known.getBitWidth() &&
      shouldChangeType(Known.getBitWidth(), NewWidth)) {
    IntegerType *Ty = IntegerType::get(SI.getContext(), NewWidth);
    Builder.SetInsertPoint(&SI);
    Value *NewCond = Builder.CreateTrunc(Cond, Ty, "trunc");

    for (auto Case : SI.cases()) {
      APInt TruncatedCase = Case.getCaseValue()->getValue().trunc(NewWidth);
      Case.setValue(ConstantInt::get(SI.getContext(), TruncatedCase));
    }
    return replaceOperand(SI, 0, NewCond);
  }

  return nullptr;
}

/// Simplify the extractvalue \p EV according to what produces its aggregate.
///
/// Handled producers: anything simplifyExtractValueInst() can fold, an
/// insertvalue (by comparing index lists), a with-overflow intrinsic, and a
/// simple single-use load (rewritten to a narrower load from a GEP).
///
/// \returns a replacement or new instruction, or nullptr if nothing applied.
Instruction *InstCombinerImpl::visitExtractValueInst(ExtractValueInst &EV) {
  Value *Agg = EV.getAggregateOperand();

  // No indices: the extract yields the aggregate itself.
  if (!EV.hasIndices())
    return replaceInstUsesWith(EV, Agg);

  if (Value *V = simplifyExtractValueInst(Agg, EV.getIndices(),
                                          SQ.getWithInstruction(&EV)))
    return replaceInstUsesWith(EV, V);

  if (InsertValueInst *IV = dyn_cast<InsertValueInst>(Agg)) {
    // We're extracting from an insertvalue instruction, compare the indices
    const unsigned *exti, *exte, *insi, *inse;
    for (exti = EV.idx_begin(), insi = IV->idx_begin(),
         exte = EV.idx_end(), inse = IV->idx_end();
         exti != exte && insi != inse;
         ++exti, ++insi) {
      if (*insi != *exti)
        // The insert and extract both reference distinctly different elements.
        // This means the extract is not influenced by the insert, and we can
        // replace the aggregate operand of the extract with the aggregate
        // operand of the insert. i.e., replace
        // %I = insertvalue { i32, { i32 } } %A, { i32 } { i32 42 }, 1
        // %E = extractvalue { i32, { i32 } } %I, 0
        // with
        // %E = extractvalue { i32, { i32 } } %A, 0
        return ExtractValueInst::Create(IV->getAggregateOperand(),
                                        EV.getIndices());
    }
    if (exti == exte && insi == inse)
      // Both iterators are at the end: Index lists are identical. Replace
      // %B = insertvalue { i32, { i32 } } %A, i32 42, 1, 0
      // %C = extractvalue { i32, { i32 } } %B, 1, 0
      // with "i32 42"
      return replaceInstUsesWith(EV, IV->getInsertedValueOperand());
    if (exti == exte) {
      // The extract list is a prefix of the insert list. i.e. replace
      // %I = insertvalue { i32, { i32 } } %A, i32 42, 1, 0
      // %E = extractvalue { i32, { i32 } } %I, 1
      // with
      // %X = extractvalue { i32, { i32 } } %A, 1
      // %E = insertvalue { i32 } %X, i32 42, 0
      // by switching the order of the insert and extract (though the
      // insertvalue should be left in, since it may have other uses).
      Value *NewEV = Builder.CreateExtractValue(IV->getAggregateOperand(),
                                                EV.getIndices());
      return InsertValueInst::Create(NewEV, IV->getInsertedValueOperand(),
                                     makeArrayRef(insi, inse));
    }
    if (insi == inse)
      // The insert list is a prefix of the extract list
      // We can simply remove the common indices from the extract and make it
      // operate on the inserted value instead of the insertvalue result.
      // i.e., replace
      // %I = insertvalue { i32, { i32 } } %A, { i32 } { i32 42 }, 1
      // %E = extractvalue { i32, { i32 } } %I, 1, 0
      // with
      // %E = extractvalue { i32 } { i32 42 }, 0
      return ExtractValueInst::Create(IV->getInsertedValueOperand(),
                                      makeArrayRef(exti, exte));
  }
  if (WithOverflowInst *WO = dyn_cast<WithOverflowInst>(Agg)) {
    // extractvalue (any_mul_with_overflow X, -1), 0 --> -X
    Intrinsic::ID OvID = WO->getIntrinsicID();
    if (*EV.idx_begin() == 0 &&
        (OvID == Intrinsic::smul_with_overflow ||
         OvID == Intrinsic::umul_with_overflow) &&
        match(WO->getArgOperand(1), m_AllOnes())) {
      return BinaryOperator::CreateNeg(WO->getArgOperand(0));
    }

    // We're extracting from an overflow intrinsic, see if we're the only user,
    // which allows us to simplify multiple result intrinsics to simpler
    // things that just get one value.
    if (WO->hasOneUse()) {
      // Check if we're grabbing only the result of a 'with overflow' intrinsic
      // and replace it with a traditional binary instruction.
      if (*EV.idx_begin() == 0) {
        // Read the opcode and operands before WO is erased.
        Instruction::BinaryOps BinOp = WO->getBinaryOp();
        Value *LHS = WO->getLHS(), *RHS = WO->getRHS();
        // Replace the old instruction's uses with poison.
        replaceInstUsesWith(*WO, PoisonValue::get(WO->getType()));
        eraseInstFromFunction(*WO);
        return BinaryOperator::Create(BinOp, LHS, RHS);
      }

      assert(*EV.idx_begin() == 1 &&
             "unexpected extract index for overflow inst");

      // If only the overflow result is used, and the right hand side is a
      // constant (or constant splat), we can remove the intrinsic by directly
      // checking for overflow.
      const APInt *C;
      if (match(WO->getRHS(), m_APInt(C))) {
        // Compute the no-wrap range for LHS given RHS=C, then construct an
        // equivalent icmp, potentially using an offset.
        ConstantRange NWR =
          ConstantRange::makeExactNoWrapRegion(WO->getBinaryOp(), *C,
                                               WO->getNoWrapKind());

        CmpInst::Predicate Pred;
        APInt NewRHSC, Offset;
        NWR.getEquivalentICmp(Pred, NewRHSC, Offset);
        auto *OpTy = WO->getRHS()->getType();
        auto *NewLHS = WO->getLHS();
        if (Offset != 0)
          NewLHS = Builder.CreateAdd(NewLHS, ConstantInt::get(OpTy, Offset));
        // Pred tests membership in the no-wrap region, so the overflow bit is
        // the inverse predicate.
        return new ICmpInst(ICmpInst::getInversePredicate(Pred), NewLHS,
                            ConstantInt::get(OpTy, NewRHSC));
      }
    }
  }
  if (LoadInst *L = dyn_cast<LoadInst>(Agg))
    // If the (non-volatile) load only has one use, we can rewrite this to a
    // load from a GEP. This reduces the size of the load. If a load is used
    // only by extractvalue instructions then this either must have been
    // optimized before, or it is a struct with padding, in which case we
    // don't want to do the transformation as it loses padding knowledge.
    if (L->isSimple() && L->hasOneUse()) {
      // extractvalue has integer indices, getelementptr has Value*s. Convert.
      SmallVector<Value*, 4> Indices;
      // Prefix an i32 0 since we need the first element.
      Indices.push_back(Builder.getInt32(0));
      for (unsigned Idx : EV.indices())
        Indices.push_back(Builder.getInt32(Idx));

      // We need to insert these at the location of the old load, not at that of
      // the extractvalue.
      Builder.SetInsertPoint(L);
      Value *GEP = Builder.CreateInBoundsGEP(L->getType(),
                                             L->getPointerOperand(), Indices);
      Instruction *NL = Builder.CreateLoad(EV.getType(), GEP);
      // Whatever aliasing information we had for the original load must also
      // hold for the smaller load, so propagate the annotations.
      NL->setAAMetadata(L->getAAMetadata());
      // Returning the load directly will cause the main loop to insert it in
      // the wrong spot, so use replaceInstUsesWith().
      return replaceInstUsesWith(EV, NL);
    }
  // We could simplify extracts from other values. Note that nested extracts may
  // already be simplified implicitly by the above: extract (extract (insert) )
  // will be translated into extract ( insert ( extract ) ) first and then just
  // the value inserted, if appropriate. Similarly for extracts from single-use
  // loads: extract (extract (load)) will be translated to extract (load (gep))
  // and if again single-use then via load (gep (gep)) to load (gep).
  // However, double extracts from e.g. function arguments or return values
  // aren't handled yet.
  return nullptr;
}

/// Return 'true' if the given typeinfo will match anything.
///
/// Only the C++/ObjC/MSVC/CoreCLR/Wasm/XL personalities listed in the last
/// case group can answer true, and only for a null \p TypeInfo.
static bool isCatchAll(EHPersonality Personality, Constant *TypeInfo) {
  switch (Personality) {
  case EHPersonality::GNU_C:
  case EHPersonality::GNU_C_SjLj:
  case EHPersonality::Rust:
    // The GCC C EH and Rust personality only exists to support cleanups, so
    // it's not clear what the semantics of catch clauses are.
    return false;
  case EHPersonality::Unknown:
    return false;
  case EHPersonality::GNU_Ada:
    // While __gnat_all_others_value will match any Ada exception, it doesn't
    // match foreign exceptions (or didn't, before gcc-4.7).
    return false;
  case EHPersonality::GNU_CXX:
  case EHPersonality::GNU_CXX_SjLj:
  case EHPersonality::GNU_ObjC:
  case EHPersonality::MSVC_X86SEH:
  case EHPersonality::MSVC_TableSEH:
  case EHPersonality::MSVC_CXX:
  case EHPersonality::CoreCLR:
  case EHPersonality::Wasm_CXX:
  case EHPersonality::XL_CXX:
    return TypeInfo->isNullValue();
  }
  llvm_unreachable("invalid enum");
}

/// Strict ordering of filter clauses by element count: true if the array type
/// of \p LHS has fewer elements than that of \p RHS. Both values must be of
/// array type (enforced by cast<>).
static bool shorter_filter(const Value *LHS, const Value *RHS) {
  return
    cast<ArrayType>(LHS->getType())->getNumElements()
  <
    cast<ArrayType>(RHS->getType())->getNumElements();
}

/// Simplify the clause list of the landingpad \p LI.
///
/// Drops repeated catch clauses, truncates the list after a catch-all or an
/// empty filter, removes duplicate elements inside filters, discards filters
/// that contain a catch-all, sorts runs of consecutive filters by length, and
/// removes a filter when an earlier filter is a subset of it.
///
/// \returns a new LandingPadInst if the clauses changed, \p LI if only its
/// cleanup flag was cleared, and nullptr otherwise.
Instruction *InstCombinerImpl::visitLandingPadInst(LandingPadInst &LI) {
  // The logic here should be correct for any real-world personality function.
// NOTE: this span resumes visitLandingPadInst() immediately after its
// signature and the first line of its opening comment.
  // However if that turns out not to be true, the offending logic can always
  // be conditioned on the personality function, like the catch-all logic is.
  EHPersonality Personality =
      classifyEHPersonality(LI.getParent()->getParent()->getPersonalityFn());

  // Simplify the list of clauses, eg by removing repeated catch clauses
  // (these are often created by inlining).
  bool MakeNewInstruction = false; // If true, recreate using the following:
  SmallVector<Constant *, 16> NewClauses; // - Clauses for the new instruction;
  bool CleanupFlag = LI.isCleanup();      // - The new instruction is a cleanup.

  SmallPtrSet<Value *, 16> AlreadyCaught; // Typeinfos known caught already.
  for (unsigned i = 0, e = LI.getNumClauses(); i != e; ++i) {
    bool isLastClause = i + 1 == e;
    if (LI.isCatch(i)) {
      // A catch clause.
      Constant *CatchClause = LI.getClause(i);
      Constant *TypeInfo = CatchClause->stripPointerCasts();

      // If we already saw this clause, there is no point in having a second
      // copy of it.
      if (AlreadyCaught.insert(TypeInfo).second) {
        // This catch clause was not already seen.
        NewClauses.push_back(CatchClause);
      } else {
        // Repeated catch clause - drop the redundant copy.
        MakeNewInstruction = true;
      }

      // If this is a catch-all then there is no point in keeping any following
      // clauses or marking the landingpad as having a cleanup.
      if (isCatchAll(Personality, TypeInfo)) {
        if (!isLastClause)
          MakeNewInstruction = true;
        CleanupFlag = false;
        break;
      }
    } else {
      // A filter clause.  If any of the filter elements were already caught
      // then they can be dropped from the filter.  It is tempting to try to
      // exploit the filter further by saying that any typeinfo that does not
      // occur in the filter can't be caught later (and thus can be dropped).
      // However this would be wrong, since typeinfos can match without being
      // equal (for example if one represents a C++ class, and the other some
      // class derived from it).
      // NOTE: the code below does not consult AlreadyCaught for filters; only
      // duplicates within a filter are dropped (see the comment in the
      // element loop).
      assert(LI.isFilter(i) && "Unsupported landingpad clause!");
      Constant *FilterClause = LI.getClause(i);
      ArrayType *FilterType = cast<ArrayType>(FilterClause->getType());
      unsigned NumTypeInfos = FilterType->getNumElements();

      // An empty filter catches everything, so there is no point in keeping any
      // following clauses or marking the landingpad as having a cleanup.  By
      // dealing with this case here the following code is made a bit simpler.
      if (!NumTypeInfos) {
        NewClauses.push_back(FilterClause);
        if (!isLastClause)
          MakeNewInstruction = true;
        CleanupFlag = false;
        break;
      }

      bool MakeNewFilter = false; // If true, make a new filter.
      SmallVector<Constant *, 16> NewFilterElts; // New elements.
      if (isa<ConstantAggregateZero>(FilterClause)) {
        // Not an empty filter - it contains at least one null typeinfo.
        assert(NumTypeInfos > 0 && "Should have handled empty filter already!");
        Constant *TypeInfo =
          Constant::getNullValue(FilterType->getElementType());
        // If this typeinfo is a catch-all then the filter can never match.
        if (isCatchAll(Personality, TypeInfo)) {
          // Throw the filter away.
          MakeNewInstruction = true;
          continue;
        }

        // There is no point in having multiple copies of this typeinfo, so
        // discard all but the first copy if there is more than one.
        NewFilterElts.push_back(TypeInfo);
        if (NumTypeInfos > 1)
          MakeNewFilter = true;
      } else {
        ConstantArray *Filter = cast<ConstantArray>(FilterClause);
        SmallPtrSet<Value *, 16> SeenInFilter; // For uniquing the elements.
        NewFilterElts.reserve(NumTypeInfos);

        // Remove any filter elements that already occurred in the filter
        // (elements caught by an earlier catch clause are deliberately kept;
        // see below).  While there, see if any of the elements are
        // catch-alls.  If so, the filter can be discarded.
        bool SawCatchAll = false;
        for (unsigned j = 0; j != NumTypeInfos; ++j) {
          Constant *Elt = Filter->getOperand(j);
          Constant *TypeInfo = Elt->stripPointerCasts();
          if (isCatchAll(Personality, TypeInfo)) {
            // This element is a catch-all.  Bail out, noting this fact.
            SawCatchAll = true;
            break;
          }

          // Even if we've seen a type in a catch clause, we don't want to
          // remove it from the filter.  An unexpected type handler may be
          // set up for a call site which throws an exception of the same
          // type caught.  In order for the exception thrown by the unexpected
          // handler to propagate correctly, the filter must be correctly
          // described for the call site.
          //
          // Example:
          //
          // void unexpected() { throw 1;}
          // void foo() throw (int) {
          //   std::set_unexpected(unexpected);
          //   try {
          //     throw 2.0;
          //   } catch (int i) {}
          // }

          // There is no point in having multiple copies of the same typeinfo in
          // a filter, so only add it if we didn't already.
          if (SeenInFilter.insert(TypeInfo).second)
            NewFilterElts.push_back(cast<Constant>(Elt));
        }
        // A filter containing a catch-all cannot match anything by definition.
        if (SawCatchAll) {
          // Throw the filter away.
          MakeNewInstruction = true;
          continue;
        }

        // If we dropped something from the filter, make a new one.
        if (NewFilterElts.size() < NumTypeInfos)
          MakeNewFilter = true;
      }
      if (MakeNewFilter) {
        FilterType = ArrayType::get(FilterType->getElementType(),
                                    NewFilterElts.size());
        FilterClause = ConstantArray::get(FilterType, NewFilterElts);
        MakeNewInstruction = true;
      }

      NewClauses.push_back(FilterClause);

      // If the new filter is empty then it will catch everything so there is
      // no point in keeping any following clauses or marking the landingpad
      // as having a cleanup.  The case of the original filter being empty was
      // already handled above.
      if (MakeNewFilter && !NewFilterElts.size()) {
        assert(MakeNewInstruction && "New filter but not a new instruction!");
        CleanupFlag = false;
        break;
      }
    }
  }

  // If several filters occur in a row then reorder them so that the shortest
  // filters come first (those with the smallest number of elements).  This is
  // advantageous because shorter filters are more likely to match, speeding up
  // unwinding, but mostly because it increases the effectiveness of the other
  // filter optimizations below.
  for (unsigned i = 0, e = NewClauses.size(); i + 1 < e; ) {
    unsigned j;
    // Find the maximal 'j' s.t. the range [i, j) consists entirely of filters.
    for (j = i; j != e; ++j)
      if (!isa<ArrayType>(NewClauses[j]->getType()))
        break;

    // Check whether the filters are already sorted by length.  We need to know
    // if sorting them is actually going to do anything so that we only make a
    // new landingpad instruction if it does.
    for (unsigned k = i; k + 1 < j; ++k)
      if (shorter_filter(NewClauses[k+1], NewClauses[k])) {
        // Not sorted, so sort the filters now.  Doing an unstable sort would be
        // correct too but reordering filters pointlessly might confuse users.
        std::stable_sort(NewClauses.begin() + i, NewClauses.begin() + j,
                         shorter_filter);
        MakeNewInstruction = true;
        break;
      }

    // Look for the next batch of filters.
    // (Clause j, if any, is not a filter, so scanning resumes just past it.)
    i = j + 1;
  }

  // If typeinfos matched if and only if equal, then the elements of a filter L
  // that occurs later than a filter F could be replaced by the intersection of
  // the elements of F and L.  In reality two typeinfos can match without being
  // equal (for example if one represents a C++ class, and the other some class
  // derived from it) so it would be wrong to perform this transform in general.
  // However the transform is correct and useful if F is a subset of L.  In that
  // case L can be replaced by F, and thus removed altogether since repeating a
  // filter is pointless.  So here we look at all pairs of filters F and L where
  // L follows F in the list of clauses, and remove L if every element of F is
  // an element of L.  This can occur when inlining C++ functions with exception
  // specifications.
  for (unsigned i = 0; i + 1 < NewClauses.size(); ++i) {
    // Examine each filter in turn.
    Value *Filter = NewClauses[i];
    ArrayType *FTy = dyn_cast<ArrayType>(Filter->getType());
    if (!FTy)
      // Not a filter - skip it.
      continue;
    unsigned FElts = FTy->getNumElements();
    // Examine each filter following this one.  Doing this backwards means that
    // we don't have to worry about filters disappearing under us when removed.
    for (unsigned j = NewClauses.size() - 1; j != i; --j) {
      Value *LFilter = NewClauses[j];
      ArrayType *LTy = dyn_cast<ArrayType>(LFilter->getType());
      if (!LTy)
        // Not a filter - skip it.
        continue;
      // If Filter is a subset of LFilter, i.e. every element of Filter is also
      // an element of LFilter, then discard LFilter.
      SmallVectorImpl<Constant *>::iterator J = NewClauses.begin() + j;
      // If Filter is empty then it is a subset of LFilter.
      if (!FElts) {
        // Discard LFilter.
        NewClauses.erase(J);
        MakeNewInstruction = true;
        // Move on to the next filter.
        continue;
      }
      unsigned LElts = LTy->getNumElements();
      // If Filter is longer than LFilter then it cannot be a subset of it.
      if (FElts > LElts)
        // Move on to the next filter.
        continue;
      // At this point we know that LFilter has at least one element.
      if (isa<ConstantAggregateZero>(LFilter)) { // LFilter only contains zeros.
        // Filter is a subset of LFilter iff Filter contains only zeros (as we
        // already know that Filter is not longer than LFilter).
        if (isa<ConstantAggregateZero>(Filter)) {
          assert(FElts <= LElts && "Should have handled this case earlier!");
          // Discard LFilter.
          NewClauses.erase(J);
          MakeNewInstruction = true;
        }
        // Move on to the next filter.
        continue;
      }
      ConstantArray *LArray = cast<ConstantArray>(LFilter);
      if (isa<ConstantAggregateZero>(Filter)) { // Filter only contains zeros.
        // Since Filter is non-empty and contains only zeros, it is a subset of
        // LFilter iff LFilter contains a zero.
        assert(FElts > 0 && "Should have eliminated the empty filter earlier!");
        for (unsigned l = 0; l != LElts; ++l)
          if (LArray->getOperand(l)->isNullValue()) {
            // LFilter contains a zero - discard it.
            NewClauses.erase(J);
            MakeNewInstruction = true;
            break;
          }
        // Move on to the next filter.
        continue;
      }
      // At this point we know that both filters are ConstantArrays.  Loop over
      // operands to see whether every element of Filter is also an element of
      // LFilter.  Since filters tend to be short this is probably faster than
      // using a method that scales nicely.
      ConstantArray *FArray = cast<ConstantArray>(Filter);
      bool AllFound = true;
      for (unsigned f = 0; f != FElts; ++f) {
        Value *FTypeInfo = FArray->getOperand(f)->stripPointerCasts();
        // Reset per element; set again only if a match is found in LFilter.
        AllFound = false;
        for (unsigned l = 0; l != LElts; ++l) {
          Value *LTypeInfo = LArray->getOperand(l)->stripPointerCasts();
          if (LTypeInfo == FTypeInfo) {
            AllFound = true;
            break;
          }
        }
        if (!AllFound)
          break;
      }
      if (AllFound) {
        // Discard LFilter.
        NewClauses.erase(J);
        MakeNewInstruction = true;
      }
      // Move on to the next filter.
    }
  }

  // If we changed any of the clauses, replace the old landingpad instruction
  // with a new one.
  if (MakeNewInstruction) {
    LandingPadInst *NLI = LandingPadInst::Create(LI.getType(),
                                                 NewClauses.size());
    for (unsigned i = 0, e = NewClauses.size(); i != e; ++i)
      NLI->addClause(NewClauses[i]);
    // A landing pad with no clauses must have the cleanup flag set.  It is
    // theoretically possible, though highly unlikely, that we eliminated all
    // clauses.  If so, force the cleanup flag to true.
    if (NewClauses.empty())
      CleanupFlag = true;
    NLI->setCleanup(CleanupFlag);
    return NLI;
  }

  // Even if none of the clauses changed, we may nonetheless have understood
  // that the cleanup flag is pointless.  Clear it if so.
  if (LI.isCleanup() != CleanupFlag) {
    assert(!CleanupFlag && "Adding a cleanup, not removing one?!");
    LI.setCleanup(CleanupFlag);
    return &LI;
  }

  return nullptr;
}

Value *
InstCombinerImpl::pushFreezeToPreventPoisonFromPropagating(FreezeInst &OrigFI) {
  // Try to push freeze through instructions that propagate but don't produce
  // poison as far as possible.  If an operand of freeze follows three
  // conditions 1) one-use, 2) does not produce poison, and 3) has all but one
  // guaranteed-non-poison operands then push the freeze through to the one
  // operand that is not guaranteed non-poison.  The actual transform is as
  // follows.
  //   Op1 = ...                        ; Op1 can be poison
  //   Op0 = Inst(Op1, NonPoisonOps...) ; Op0 has only one use and only have
  //                                    ; single guaranteed-non-poison operands
  //   ... = Freeze(Op0)
  // =>
  //   Op1 = ...
  //   Op1.fr = Freeze(Op1)
  //   ... = Inst(Op1.fr, NonPoisonOps...)
  auto *OrigOp = OrigFI.getOperand(0);
  auto *OrigOpInst = dyn_cast<Instruction>(OrigOp);

  // While we could change the other users of OrigOp to use freeze(OrigOp), that
  // potentially reduces their optimization potential, so let's only do this iff
  // the OrigOp is only used by the freeze.
  if (!OrigOpInst || !OrigOpInst->hasOneUse() || isa<PHINode>(OrigOp))
    return nullptr;

  // We can't push the freeze through an instruction which can itself create
  // poison.  If the only source of new poison is flags, we can simply
  // strip them (since we know the only use is the freeze and nothing can
  // benefit from them.)
  if (canCreateUndefOrPoison(cast<Operator>(OrigOp), /*ConsiderFlags*/ false))
    return nullptr;

  // If operand is guaranteed not to be poison, there is no need to add freeze
  // to the operand.
So we first find the operand that is not guaranteed to be + // poison. + Use *MaybePoisonOperand = nullptr; + for (Use &U : OrigOpInst->operands()) { + if (isGuaranteedNotToBeUndefOrPoison(U.get())) + continue; + if (!MaybePoisonOperand) + MaybePoisonOperand = &U; + else + return nullptr; + } + + OrigOpInst->dropPoisonGeneratingFlags(); + + // If all operands are guaranteed to be non-poison, we can drop freeze. + if (!MaybePoisonOperand) + return OrigOp; + + Builder.SetInsertPoint(OrigOpInst); + auto *FrozenMaybePoisonOperand = Builder.CreateFreeze( + MaybePoisonOperand->get(), MaybePoisonOperand->get()->getName() + ".fr"); + + replaceUse(*MaybePoisonOperand, FrozenMaybePoisonOperand); + return OrigOp; +} + +Instruction *InstCombinerImpl::foldFreezeIntoRecurrence(FreezeInst &FI, + PHINode *PN) { + // Detect whether this is a recurrence with a start value and some number of + // backedge values. We'll check whether we can push the freeze through the + // backedge values (possibly dropping poison flags along the way) until we + // reach the phi again. In that case, we can move the freeze to the start + // value. + Use *StartU = nullptr; + SmallVector<Value *> Worklist; + for (Use &U : PN->incoming_values()) { + if (DT.dominates(PN->getParent(), PN->getIncomingBlock(U))) { + // Add backedge value to worklist. + Worklist.push_back(U.get()); + continue; + } + + // Don't bother handling multiple start values. + if (StartU) + return nullptr; + StartU = &U; + } + + if (!StartU || Worklist.empty()) + return nullptr; // Not a recurrence. + + Value *StartV = StartU->get(); + BasicBlock *StartBB = PN->getIncomingBlock(*StartU); + bool StartNeedsFreeze = !isGuaranteedNotToBeUndefOrPoison(StartV); + // We can't insert freeze if the the start value is the result of the + // terminator (e.g. an invoke). 
+ if (StartNeedsFreeze && StartBB->getTerminator() == StartV) + return nullptr; + + SmallPtrSet<Value *, 32> Visited; + SmallVector<Instruction *> DropFlags; + while (!Worklist.empty()) { + Value *V = Worklist.pop_back_val(); + if (!Visited.insert(V).second) + continue; + + if (Visited.size() > 32) + return nullptr; // Limit the total number of values we inspect. + + // Assume that PN is non-poison, because it will be after the transform. + if (V == PN || isGuaranteedNotToBeUndefOrPoison(V)) + continue; + + Instruction *I = dyn_cast<Instruction>(V); + if (!I || canCreateUndefOrPoison(cast<Operator>(I), + /*ConsiderFlags*/ false)) + return nullptr; + + DropFlags.push_back(I); + append_range(Worklist, I->operands()); + } + + for (Instruction *I : DropFlags) + I->dropPoisonGeneratingFlags(); + + if (StartNeedsFreeze) { + Builder.SetInsertPoint(StartBB->getTerminator()); + Value *FrozenStartV = Builder.CreateFreeze(StartV, + StartV->getName() + ".fr"); + replaceUse(*StartU, FrozenStartV); + } + return replaceInstUsesWith(FI, PN); +} + +bool InstCombinerImpl::freezeOtherUses(FreezeInst &FI) { + Value *Op = FI.getOperand(0); + + if (isa<Constant>(Op) || Op->hasOneUse()) + return false; + + // Move the freeze directly after the definition of its operand, so that + // it dominates the maximum number of uses. Note that it may not dominate + // *all* uses if the operand is an invoke/callbr and the use is in a phi on + // the normal/default destination. This is why the domination check in the + // replacement below is still necessary. 
+ Instruction *MoveBefore = nullptr; + if (isa<Argument>(Op)) { + MoveBefore = &FI.getFunction()->getEntryBlock().front(); + while (isa<AllocaInst>(MoveBefore)) + MoveBefore = MoveBefore->getNextNode(); + } else if (auto *PN = dyn_cast<PHINode>(Op)) { + MoveBefore = PN->getParent()->getFirstNonPHI(); + } else if (auto *II = dyn_cast<InvokeInst>(Op)) { + MoveBefore = II->getNormalDest()->getFirstNonPHI(); + } else if (auto *CB = dyn_cast<CallBrInst>(Op)) { + MoveBefore = CB->getDefaultDest()->getFirstNonPHI(); + } else { + auto *I = cast<Instruction>(Op); + assert(!I->isTerminator() && "Cannot be a terminator"); + MoveBefore = I->getNextNode(); + } + + bool Changed = false; + if (&FI != MoveBefore) { + FI.moveBefore(MoveBefore); + Changed = true; + } + + Op->replaceUsesWithIf(&FI, [&](Use &U) -> bool { + bool Dominates = DT.dominates(&FI, U); + Changed |= Dominates; + return Dominates; + }); + + return Changed; +} + +Instruction *InstCombinerImpl::visitFreeze(FreezeInst &I) { + Value *Op0 = I.getOperand(0); + + if (Value *V = simplifyFreezeInst(Op0, SQ.getWithInstruction(&I))) + return replaceInstUsesWith(I, V); + + // freeze (phi const, x) --> phi const, (freeze x) + if (auto *PN = dyn_cast<PHINode>(Op0)) { + if (Instruction *NV = foldOpIntoPhi(I, PN)) + return NV; + if (Instruction *NV = foldFreezeIntoRecurrence(I, PN)) + return NV; + } + + if (Value *NI = pushFreezeToPreventPoisonFromPropagating(I)) + return replaceInstUsesWith(I, NI); + + // If I is freeze(undef), check its uses and fold it to a fixed constant. + // - or: pick -1 + // - select's condition: if the true value is constant, choose it by making + // the condition true. + // - default: pick 0 + // + // Note that this transform is intentionally done here rather than + // via an analysis in InstSimplify or at individual user sites. That is + // because we must produce the same value for all uses of the freeze - + // it's the reason "freeze" exists! 
+ // + // TODO: This could use getBinopAbsorber() / getBinopIdentity() to avoid + // duplicating logic for binops at least. + auto getUndefReplacement = [&I](Type *Ty) { + Constant *BestValue = nullptr; + Constant *NullValue = Constant::getNullValue(Ty); + for (const auto *U : I.users()) { + Constant *C = NullValue; + if (match(U, m_Or(m_Value(), m_Value()))) + C = ConstantInt::getAllOnesValue(Ty); + else if (match(U, m_Select(m_Specific(&I), m_Constant(), m_Value()))) + C = ConstantInt::getTrue(Ty); + + if (!BestValue) + BestValue = C; + else if (BestValue != C) + BestValue = NullValue; + } + assert(BestValue && "Must have at least one use"); + return BestValue; + }; + + if (match(Op0, m_Undef())) + return replaceInstUsesWith(I, getUndefReplacement(I.getType())); + + Constant *C; + if (match(Op0, m_Constant(C)) && C->containsUndefOrPoisonElement()) { + Constant *ReplaceC = getUndefReplacement(I.getType()->getScalarType()); + return replaceInstUsesWith(I, Constant::replaceUndefsWith(C, ReplaceC)); + } + + // Replace uses of Op with freeze(Op). + if (freezeOtherUses(I)) + return &I; + + return nullptr; +} + +/// Check for case where the call writes to an otherwise dead alloca. This +/// shows up for unused out-params in idiomatic C/C++ code. Note that this +/// helper *only* analyzes the write; doesn't check any other legality aspect. +static bool SoleWriteToDeadLocal(Instruction *I, TargetLibraryInfo &TLI) { + auto *CB = dyn_cast<CallBase>(I); + if (!CB) + // TODO: handle e.g. store to alloca here - only worth doing if we extend + // to allow reload along used path as described below. Otherwise, this + // is simply a store to a dead allocation which will be removed. + return false; + Optional<MemoryLocation> Dest = MemoryLocation::getForDest(CB, TLI); + if (!Dest) + return false; + auto *AI = dyn_cast<AllocaInst>(getUnderlyingObject(Dest->Ptr)); + if (!AI) + // TODO: allow malloc? + return false; + // TODO: allow memory access dominated by move point? 
Note that since AI + // could have a reference to itself captured by the call, we would need to + // account for cycles in doing so. + SmallVector<const User *> AllocaUsers; + SmallPtrSet<const User *, 4> Visited; + auto pushUsers = [&](const Instruction &I) { + for (const User *U : I.users()) { + if (Visited.insert(U).second) + AllocaUsers.push_back(U); + } + }; + pushUsers(*AI); + while (!AllocaUsers.empty()) { + auto *UserI = cast<Instruction>(AllocaUsers.pop_back_val()); + if (isa<BitCastInst>(UserI) || isa<GetElementPtrInst>(UserI) || + isa<AddrSpaceCastInst>(UserI)) { + pushUsers(*UserI); + continue; + } + if (UserI == CB) + continue; + // TODO: support lifetime.start/end here + return false; + } + return true; +} + +/// Try to move the specified instruction from its current block into the +/// beginning of DestBlock, which can only happen if it's safe to move the +/// instruction past all of the instructions between it and the end of its +/// block. +static bool TryToSinkInstruction(Instruction *I, BasicBlock *DestBlock, + TargetLibraryInfo &TLI) { + BasicBlock *SrcBlock = I->getParent(); + + // Cannot move control-flow-involving, volatile loads, vaarg, etc. + if (isa<PHINode>(I) || I->isEHPad() || I->mayThrow() || !I->willReturn() || + I->isTerminator()) + return false; + + // Do not sink static or dynamic alloca instructions. Static allocas must + // remain in the entry block, and dynamic allocas must not be sunk in between + // a stacksave / stackrestore pair, which would incorrectly shorten its + // lifetime. + if (isa<AllocaInst>(I)) + return false; + + // Do not sink into catchswitch blocks. + if (isa<CatchSwitchInst>(DestBlock->getTerminator())) + return false; + + // Do not sink convergent call instructions. + if (auto *CI = dyn_cast<CallInst>(I)) { + if (CI->isConvergent()) + return false; + } + + // Unless we can prove that the memory write isn't visibile except on the + // path we're sinking to, we must bail. 
+ if (I->mayWriteToMemory()) { + if (!SoleWriteToDeadLocal(I, TLI)) + return false; + } + + // We can only sink load instructions if there is nothing between the load and + // the end of block that could change the value. + if (I->mayReadFromMemory()) { + // We don't want to do any sophisticated alias analysis, so we only check + // the instructions after I in I's parent block if we try to sink to its + // successor block. + if (DestBlock->getUniquePredecessor() != I->getParent()) + return false; + for (BasicBlock::iterator Scan = std::next(I->getIterator()), + E = I->getParent()->end(); + Scan != E; ++Scan) + if (Scan->mayWriteToMemory()) + return false; + } + + I->dropDroppableUses([DestBlock](const Use *U) { + if (auto *I = dyn_cast<Instruction>(U->getUser())) + return I->getParent() != DestBlock; + return true; + }); + /// FIXME: We could remove droppable uses that are not dominated by + /// the new position. + + BasicBlock::iterator InsertPos = DestBlock->getFirstInsertionPt(); + I->moveBefore(&*InsertPos); + ++NumSunkInst; + + // Also sink all related debug uses from the source basic block. Otherwise we + // get debug use before the def. Attempt to salvage debug uses first, to + // maximise the range variables have location for. If we cannot salvage, then + // mark the location undef: we know it was supposed to receive a new location + // here, but that computation has been sunk. + SmallVector<DbgVariableIntrinsic *, 2> DbgUsers; + findDbgUsers(DbgUsers, I); + // Process the sinking DbgUsers in reverse order, as we only want to clone the + // last appearing debug intrinsic for each given variable. 
+ SmallVector<DbgVariableIntrinsic *, 2> DbgUsersToSink; + for (DbgVariableIntrinsic *DVI : DbgUsers) + if (DVI->getParent() == SrcBlock) + DbgUsersToSink.push_back(DVI); + llvm::sort(DbgUsersToSink, + [](auto *A, auto *B) { return B->comesBefore(A); }); + + SmallVector<DbgVariableIntrinsic *, 2> DIIClones; + SmallSet<DebugVariable, 4> SunkVariables; + for (auto User : DbgUsersToSink) { + // A dbg.declare instruction should not be cloned, since there can only be + // one per variable fragment. It should be left in the original place + // because the sunk instruction is not an alloca (otherwise we could not be + // here). + if (isa<DbgDeclareInst>(User)) + continue; + + DebugVariable DbgUserVariable = + DebugVariable(User->getVariable(), User->getExpression(), + User->getDebugLoc()->getInlinedAt()); + + if (!SunkVariables.insert(DbgUserVariable).second) + continue; + + DIIClones.emplace_back(cast<DbgVariableIntrinsic>(User->clone())); + if (isa<DbgDeclareInst>(User) && isa<CastInst>(I)) + DIIClones.back()->replaceVariableLocationOp(I, I->getOperand(0)); + LLVM_DEBUG(dbgs() << "CLONE: " << *DIIClones.back() << '\n'); + } + + // Perform salvaging without the clones, then sink the clones. + if (!DIIClones.empty()) { + salvageDebugInfoForDbgValues(*I, DbgUsers); + // The clones are in reverse order of original appearance, reverse again to + // maintain the original order. + for (auto &DIIClone : llvm::reverse(DIIClones)) { + DIIClone->insertBefore(&*InsertPos); + LLVM_DEBUG(dbgs() << "SINK: " << *DIIClone << '\n'); + } + } + + return true; +} + +bool InstCombinerImpl::run() { + while (!Worklist.isEmpty()) { + // Walk deferred instructions in reverse order, and push them to the + // worklist, which means they'll end up popped from the worklist in-order. + while (Instruction *I = Worklist.popDeferred()) { + // Check to see if we can DCE the instruction. We do this already here to + // reduce the number of uses and thus allow other folds to trigger. 
+ // Note that eraseInstFromFunction() may push additional instructions on + // the deferred worklist, so this will DCE whole instruction chains. + if (isInstructionTriviallyDead(I, &TLI)) { + eraseInstFromFunction(*I); + ++NumDeadInst; + continue; + } + + Worklist.push(I); + } + + Instruction *I = Worklist.removeOne(); + if (I == nullptr) continue; // skip null values. + + // Check to see if we can DCE the instruction. + if (isInstructionTriviallyDead(I, &TLI)) { + eraseInstFromFunction(*I); + ++NumDeadInst; + continue; + } + + if (!DebugCounter::shouldExecute(VisitCounter)) + continue; + + // Instruction isn't dead, see if we can constant propagate it. + if (!I->use_empty() && + (I->getNumOperands() == 0 || isa<Constant>(I->getOperand(0)))) { + if (Constant *C = ConstantFoldInstruction(I, DL, &TLI)) { + LLVM_DEBUG(dbgs() << "IC: ConstFold to: " << *C << " from: " << *I + << '\n'); + + // Add operands to the worklist. + replaceInstUsesWith(*I, C); + ++NumConstProp; + if (isInstructionTriviallyDead(I, &TLI)) + eraseInstFromFunction(*I); + MadeIRChange = true; + continue; + } + } + + // See if we can trivially sink this instruction to its user if we can + // prove that the successor is not executed more frequently than our block. + // Return the UserBlock if successful. + auto getOptionalSinkBlockForInst = + [this](Instruction *I) -> Optional<BasicBlock *> { + if (!EnableCodeSinking) + return None; + + BasicBlock *BB = I->getParent(); + BasicBlock *UserParent = nullptr; + unsigned NumUsers = 0; + + for (auto *U : I->users()) { + if (U->isDroppable()) + continue; + if (NumUsers > MaxSinkNumUsers) + return None; + + Instruction *UserInst = cast<Instruction>(U); + // Special handling for Phi nodes - get the block the use occurs in. + if (PHINode *PN = dyn_cast<PHINode>(UserInst)) { + for (unsigned i = 0; i < PN->getNumIncomingValues(); i++) { + if (PN->getIncomingValue(i) == I) { + // Bail out if we have uses in different blocks. 
We don't do any + // sophisticated analysis (i.e finding NearestCommonDominator of + // these use blocks). + if (UserParent && UserParent != PN->getIncomingBlock(i)) + return None; + UserParent = PN->getIncomingBlock(i); + } + } + assert(UserParent && "expected to find user block!"); + } else { + if (UserParent && UserParent != UserInst->getParent()) + return None; + UserParent = UserInst->getParent(); + } + + // Make sure these checks are done only once, naturally we do the checks + // the first time we get the userparent, this will save compile time. + if (NumUsers == 0) { + // Try sinking to another block. If that block is unreachable, then do + // not bother. SimplifyCFG should handle it. + if (UserParent == BB || !DT.isReachableFromEntry(UserParent)) + return None; + + auto *Term = UserParent->getTerminator(); + // See if the user is one of our successors that has only one + // predecessor, so that we don't have to split the critical edge. + // Another option where we can sink is a block that ends with a + // terminator that does not pass control to other block (such as + // return or unreachable or resume). In this case: + // - I dominates the User (by SSA form); + // - the User will be executed at most once. + // So sinking I down to User is always profitable or neutral. + if (UserParent->getUniquePredecessor() != BB && !succ_empty(Term)) + return None; + + assert(DT.dominates(BB, UserParent) && "Dominance relation broken?"); + } + + NumUsers++; + } + + // No user or only has droppable users. + if (!UserParent) + return None; + + return UserParent; + }; + + auto OptBB = getOptionalSinkBlockForInst(I); + if (OptBB) { + auto *UserParent = *OptBB; + // Okay, the CFG is simple enough, try to sink this instruction. 
+ if (TryToSinkInstruction(I, UserParent, TLI)) { + LLVM_DEBUG(dbgs() << "IC: Sink: " << *I << '\n'); + MadeIRChange = true; + // We'll add uses of the sunk instruction below, but since + // sinking can expose opportunities for it's *operands* add + // them to the worklist + for (Use &U : I->operands()) + if (Instruction *OpI = dyn_cast<Instruction>(U.get())) + Worklist.push(OpI); + } + } + + // Now that we have an instruction, try combining it to simplify it. + Builder.SetInsertPoint(I); + Builder.CollectMetadataToCopy( + I, {LLVMContext::MD_dbg, LLVMContext::MD_annotation}); + +#ifndef NDEBUG + std::string OrigI; +#endif + LLVM_DEBUG(raw_string_ostream SS(OrigI); I->print(SS); OrigI = SS.str();); + LLVM_DEBUG(dbgs() << "IC: Visiting: " << OrigI << '\n'); + + if (Instruction *Result = visit(*I)) { + ++NumCombined; + // Should we replace the old instruction with a new one? + if (Result != I) { + LLVM_DEBUG(dbgs() << "IC: Old = " << *I << '\n' + << " New = " << *Result << '\n'); + + Result->copyMetadata(*I, + {LLVMContext::MD_dbg, LLVMContext::MD_annotation}); + // Everything uses the new instruction now. + I->replaceAllUsesWith(Result); + + // Move the name to the new instruction first. + Result->takeName(I); + + // Insert the new instruction into the basic block... + BasicBlock *InstParent = I->getParent(); + BasicBlock::iterator InsertPos = I->getIterator(); + + // Are we replace a PHI with something that isn't a PHI, or vice versa? + if (isa<PHINode>(Result) != isa<PHINode>(I)) { + // We need to fix up the insertion point. + if (isa<PHINode>(I)) // PHI -> Non-PHI + InsertPos = InstParent->getFirstInsertionPt(); + else // Non-PHI -> PHI + InsertPos = InstParent->getFirstNonPHI()->getIterator(); + } + + InstParent->getInstList().insert(InsertPos, Result); + + // Push the new instruction and any users onto the worklist. 
+ Worklist.pushUsersToWorkList(*Result); + Worklist.push(Result); + + eraseInstFromFunction(*I); + } else { + LLVM_DEBUG(dbgs() << "IC: Mod = " << OrigI << '\n' + << " New = " << *I << '\n'); + + // If the instruction was modified, it's possible that it is now dead. + // if so, remove it. + if (isInstructionTriviallyDead(I, &TLI)) { + eraseInstFromFunction(*I); + } else { + Worklist.pushUsersToWorkList(*I); + Worklist.push(I); + } + } + MadeIRChange = true; + } + } + + Worklist.zap(); + return MadeIRChange; +} + +// Track the scopes used by !alias.scope and !noalias. In a function, a +// @llvm.experimental.noalias.scope.decl is only useful if that scope is used +// by both sets. If not, the declaration of the scope can be safely omitted. +// The MDNode of the scope can be omitted as well for the instructions that are +// part of this function. We do not do that at this point, as this might become +// too time consuming to do. +class AliasScopeTracker { + SmallPtrSet<const MDNode *, 8> UsedAliasScopesAndLists; + SmallPtrSet<const MDNode *, 8> UsedNoAliasScopesAndLists; + +public: + void analyse(Instruction *I) { + // This seems to be faster than checking 'mayReadOrWriteMemory()'. 
+ if (!I->hasMetadataOtherThanDebugLoc()) + return; + + auto Track = [](Metadata *ScopeList, auto &Container) { + const auto *MDScopeList = dyn_cast_or_null<MDNode>(ScopeList); + if (!MDScopeList || !Container.insert(MDScopeList).second) + return; + for (auto &MDOperand : MDScopeList->operands()) + if (auto *MDScope = dyn_cast<MDNode>(MDOperand)) + Container.insert(MDScope); + }; + + Track(I->getMetadata(LLVMContext::MD_alias_scope), UsedAliasScopesAndLists); + Track(I->getMetadata(LLVMContext::MD_noalias), UsedNoAliasScopesAndLists); + } + + bool isNoAliasScopeDeclDead(Instruction *Inst) { + NoAliasScopeDeclInst *Decl = dyn_cast<NoAliasScopeDeclInst>(Inst); + if (!Decl) + return false; + + assert(Decl->use_empty() && + "llvm.experimental.noalias.scope.decl in use ?"); + const MDNode *MDSL = Decl->getScopeList(); + assert(MDSL->getNumOperands() == 1 && + "llvm.experimental.noalias.scope should refer to a single scope"); + auto &MDOperand = MDSL->getOperand(0); + if (auto *MD = dyn_cast<MDNode>(MDOperand)) + return !UsedAliasScopesAndLists.contains(MD) || + !UsedNoAliasScopesAndLists.contains(MD); + + // Not an MDNode ? throw away. + return true; + } +}; + +/// Populate the IC worklist from a function, by walking it in depth-first +/// order and adding all reachable code to the worklist. +/// +/// This has a couple of tricks to make the code faster and more powerful. In +/// particular, we constant fold and DCE instructions as we go, to avoid adding +/// them to the worklist (this significantly speeds up instcombine on code where +/// many instructions are dead or constant). Additionally, if we find a branch +/// whose condition is a known constant, we only visit the reachable successors. 
+static bool prepareICWorklistFromFunction(Function &F, const DataLayout &DL, + const TargetLibraryInfo *TLI, + InstructionWorklist &ICWorklist) { + bool MadeIRChange = false; + SmallPtrSet<BasicBlock *, 32> Visited; + SmallVector<BasicBlock*, 256> Worklist; + Worklist.push_back(&F.front()); + + SmallVector<Instruction *, 128> InstrsForInstructionWorklist; + DenseMap<Constant *, Constant *> FoldedConstants; + AliasScopeTracker SeenAliasScopes; + + do { + BasicBlock *BB = Worklist.pop_back_val(); + + // We have now visited this block! If we've already been here, ignore it. + if (!Visited.insert(BB).second) + continue; + + for (Instruction &Inst : llvm::make_early_inc_range(*BB)) { + // ConstantProp instruction if trivially constant. + if (!Inst.use_empty() && + (Inst.getNumOperands() == 0 || isa<Constant>(Inst.getOperand(0)))) + if (Constant *C = ConstantFoldInstruction(&Inst, DL, TLI)) { + LLVM_DEBUG(dbgs() << "IC: ConstFold to: " << *C << " from: " << Inst + << '\n'); + Inst.replaceAllUsesWith(C); + ++NumConstProp; + if (isInstructionTriviallyDead(&Inst, TLI)) + Inst.eraseFromParent(); + MadeIRChange = true; + continue; + } + + // See if we can constant fold its operands. + for (Use &U : Inst.operands()) { + if (!isa<ConstantVector>(U) && !isa<ConstantExpr>(U)) + continue; + + auto *C = cast<Constant>(U); + Constant *&FoldRes = FoldedConstants[C]; + if (!FoldRes) + FoldRes = ConstantFoldConstant(C, DL, TLI); + + if (FoldRes != C) { + LLVM_DEBUG(dbgs() << "IC: ConstFold operand of: " << Inst + << "\n Old = " << *C + << "\n New = " << *FoldRes << '\n'); + U = FoldRes; + MadeIRChange = true; + } + } + + // Skip processing debug and pseudo intrinsics in InstCombine. Processing + // these call instructions consumes non-trivial amount of time and + // provides no value for the optimization. + if (!Inst.isDebugOrPseudoInst()) { + InstrsForInstructionWorklist.push_back(&Inst); + SeenAliasScopes.analyse(&Inst); + } + } + + // Recursively visit successors. 
If this is a branch or switch on a + // constant, only visit the reachable successor. + Instruction *TI = BB->getTerminator(); + if (BranchInst *BI = dyn_cast<BranchInst>(TI)) { + if (BI->isConditional() && isa<ConstantInt>(BI->getCondition())) { + bool CondVal = cast<ConstantInt>(BI->getCondition())->getZExtValue(); + BasicBlock *ReachableBB = BI->getSuccessor(!CondVal); + Worklist.push_back(ReachableBB); + continue; + } + } else if (SwitchInst *SI = dyn_cast<SwitchInst>(TI)) { + if (ConstantInt *Cond = dyn_cast<ConstantInt>(SI->getCondition())) { + Worklist.push_back(SI->findCaseValue(Cond)->getCaseSuccessor()); + continue; + } + } + + append_range(Worklist, successors(TI)); + } while (!Worklist.empty()); + + // Remove instructions inside unreachable blocks. This prevents the + // instcombine code from having to deal with some bad special cases, and + // reduces use counts of instructions. + for (BasicBlock &BB : F) { + if (Visited.count(&BB)) + continue; + + unsigned NumDeadInstInBB; + unsigned NumDeadDbgInstInBB; + std::tie(NumDeadInstInBB, NumDeadDbgInstInBB) = + removeAllNonTerminatorAndEHPadInstructions(&BB); + + MadeIRChange |= NumDeadInstInBB + NumDeadDbgInstInBB > 0; + NumDeadInst += NumDeadInstInBB; + } + + // Once we've found all of the instructions to add to instcombine's worklist, + // add them in reverse order. This way instcombine will visit from the top + // of the function down. This jives well with the way that it adds all uses + // of instructions to the worklist after doing a transformation, thus avoiding + // some N^2 behavior in pathological cases. + ICWorklist.reserve(InstrsForInstructionWorklist.size()); + for (Instruction *Inst : reverse(InstrsForInstructionWorklist)) { + // DCE instruction if trivially dead. As we iterate in reverse program + // order here, we will clean up whole chains of dead instructions. 
+ if (isInstructionTriviallyDead(Inst, TLI) || + SeenAliasScopes.isNoAliasScopeDeclDead(Inst)) { + ++NumDeadInst; + LLVM_DEBUG(dbgs() << "IC: DCE: " << *Inst << '\n'); + salvageDebugInfo(*Inst); + Inst->eraseFromParent(); + MadeIRChange = true; + continue; + } + + ICWorklist.push(Inst); + } + + return MadeIRChange; +} + +static bool combineInstructionsOverFunction( + Function &F, InstructionWorklist &Worklist, AliasAnalysis *AA, + AssumptionCache &AC, TargetLibraryInfo &TLI, TargetTransformInfo &TTI, + DominatorTree &DT, OptimizationRemarkEmitter &ORE, BlockFrequencyInfo *BFI, + ProfileSummaryInfo *PSI, unsigned MaxIterations, LoopInfo *LI) { + auto &DL = F.getParent()->getDataLayout(); + MaxIterations = std::min(MaxIterations, LimitMaxIterations.getValue()); + + /// Builder - This is an IRBuilder that automatically inserts new + /// instructions into the worklist when they are created. + IRBuilder<TargetFolder, IRBuilderCallbackInserter> Builder( + F.getContext(), TargetFolder(DL), + IRBuilderCallbackInserter([&Worklist, &AC](Instruction *I) { + Worklist.add(I); + if (auto *Assume = dyn_cast<AssumeInst>(I)) + AC.registerAssumption(Assume); + })); + + // Lower dbg.declare intrinsics otherwise their value may be clobbered + // by instcombiner. + bool MadeIRChange = false; + if (ShouldLowerDbgDeclare) + MadeIRChange = LowerDbgDeclare(F); + + // Iterate while there is work to do. 
+ unsigned Iteration = 0; + while (true) { + ++NumWorklistIterations; + ++Iteration; + + if (Iteration > InfiniteLoopDetectionThreshold) { + report_fatal_error( + "Instruction Combining seems stuck in an infinite loop after " + + Twine(InfiniteLoopDetectionThreshold) + " iterations."); + } + + if (Iteration > MaxIterations) { + LLVM_DEBUG(dbgs() << "\n\n[IC] Iteration limit #" << MaxIterations + << " on " << F.getName() + << " reached; stopping before reaching a fixpoint\n"); + break; + } + + LLVM_DEBUG(dbgs() << "\n\nINSTCOMBINE ITERATION #" << Iteration << " on " + << F.getName() << "\n"); + + MadeIRChange |= prepareICWorklistFromFunction(F, DL, &TLI, Worklist); + + InstCombinerImpl IC(Worklist, Builder, F.hasMinSize(), AA, AC, TLI, TTI, DT, + ORE, BFI, PSI, DL, LI); + IC.MaxArraySizeForCombine = MaxArraySize; + + if (!IC.run()) + break; + + MadeIRChange = true; + } + + return MadeIRChange; +} + +InstCombinePass::InstCombinePass() : MaxIterations(LimitMaxIterations) {} + +InstCombinePass::InstCombinePass(unsigned MaxIterations) + : MaxIterations(MaxIterations) {} + +PreservedAnalyses InstCombinePass::run(Function &F, + FunctionAnalysisManager &AM) { + auto &AC = AM.getResult<AssumptionAnalysis>(F); + auto &DT = AM.getResult<DominatorTreeAnalysis>(F); + auto &TLI = AM.getResult<TargetLibraryAnalysis>(F); + auto &ORE = AM.getResult<OptimizationRemarkEmitterAnalysis>(F); + auto &TTI = AM.getResult<TargetIRAnalysis>(F); + + auto *LI = AM.getCachedResult<LoopAnalysis>(F); + + auto *AA = &AM.getResult<AAManager>(F); + auto &MAMProxy = AM.getResult<ModuleAnalysisManagerFunctionProxy>(F); + ProfileSummaryInfo *PSI = + MAMProxy.getCachedResult<ProfileSummaryAnalysis>(*F.getParent()); + auto *BFI = (PSI && PSI->hasProfileSummary()) ? + &AM.getResult<BlockFrequencyAnalysis>(F) : nullptr; + + if (!combineInstructionsOverFunction(F, Worklist, AA, AC, TLI, TTI, DT, ORE, + BFI, PSI, MaxIterations, LI)) + // No changes, all analyses are preserved. 
+ return PreservedAnalyses::all(); + + // Mark all the analyses that instcombine updates as preserved. + PreservedAnalyses PA; + PA.preserveSet<CFGAnalyses>(); + return PA; +} + +void InstructionCombiningPass::getAnalysisUsage(AnalysisUsage &AU) const { + AU.setPreservesCFG(); + AU.addRequired<AAResultsWrapperPass>(); + AU.addRequired<AssumptionCacheTracker>(); + AU.addRequired<TargetLibraryInfoWrapperPass>(); + AU.addRequired<TargetTransformInfoWrapperPass>(); + AU.addRequired<DominatorTreeWrapperPass>(); + AU.addRequired<OptimizationRemarkEmitterWrapperPass>(); + AU.addPreserved<DominatorTreeWrapperPass>(); + AU.addPreserved<AAResultsWrapperPass>(); + AU.addPreserved<BasicAAWrapperPass>(); + AU.addPreserved<GlobalsAAWrapperPass>(); + AU.addRequired<ProfileSummaryInfoWrapperPass>(); + LazyBlockFrequencyInfoPass::getLazyBFIAnalysisUsage(AU); +} + +bool InstructionCombiningPass::runOnFunction(Function &F) { + if (skipFunction(F)) + return false; + + // Required analyses. + auto AA = &getAnalysis<AAResultsWrapperPass>().getAAResults(); + auto &AC = getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F); + auto &TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F); + auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F); + auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree(); + auto &ORE = getAnalysis<OptimizationRemarkEmitterWrapperPass>().getORE(); + + // Optional analyses. + auto *LIWP = getAnalysisIfAvailable<LoopInfoWrapperPass>(); + auto *LI = LIWP ? &LIWP->getLoopInfo() : nullptr; + ProfileSummaryInfo *PSI = + &getAnalysis<ProfileSummaryInfoWrapperPass>().getPSI(); + BlockFrequencyInfo *BFI = + (PSI && PSI->hasProfileSummary()) ? 
+ &getAnalysis<LazyBlockFrequencyInfoPass>().getBFI() : + nullptr; + + return combineInstructionsOverFunction(F, Worklist, AA, AC, TLI, TTI, DT, ORE, + BFI, PSI, MaxIterations, LI); +} + +char InstructionCombiningPass::ID = 0; + +InstructionCombiningPass::InstructionCombiningPass() + : FunctionPass(ID), MaxIterations(InstCombineDefaultMaxIterations) { + initializeInstructionCombiningPassPass(*PassRegistry::getPassRegistry()); +} + +InstructionCombiningPass::InstructionCombiningPass(unsigned MaxIterations) + : FunctionPass(ID), MaxIterations(MaxIterations) { + initializeInstructionCombiningPassPass(*PassRegistry::getPassRegistry()); +} + +INITIALIZE_PASS_BEGIN(InstructionCombiningPass, "instcombine", + "Combine redundant instructions", false, false) +INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) +INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) +INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass) +INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) +INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass) +INITIALIZE_PASS_DEPENDENCY(GlobalsAAWrapperPass) +INITIALIZE_PASS_DEPENDENCY(OptimizationRemarkEmitterWrapperPass) +INITIALIZE_PASS_DEPENDENCY(LazyBlockFrequencyInfoPass) +INITIALIZE_PASS_DEPENDENCY(ProfileSummaryInfoWrapperPass) +INITIALIZE_PASS_END(InstructionCombiningPass, "instcombine", + "Combine redundant instructions", false, false) + +// Initialization Routines +void llvm::initializeInstCombine(PassRegistry &Registry) { + initializeInstructionCombiningPassPass(Registry); +} + +void LLVMInitializeInstCombine(LLVMPassRegistryRef R) { + initializeInstructionCombiningPassPass(*unwrap(R)); +} + +FunctionPass *llvm::createInstructionCombiningPass() { + return new InstructionCombiningPass(); +} + +FunctionPass *llvm::createInstructionCombiningPass(unsigned MaxIterations) { + return new InstructionCombiningPass(MaxIterations); +} + +void LLVMAddInstructionCombiningPass(LLVMPassManagerRef PM) { + unwrap(PM)->add(createInstructionCombiningPass()); +} |