diff options
Diffstat (limited to 'lib/Transforms/InstCombine/InstCombineShifts.cpp')
| -rw-r--r-- | lib/Transforms/InstCombine/InstCombineShifts.cpp | 98 |
1 files changed, 68 insertions, 30 deletions
diff --git a/lib/Transforms/InstCombine/InstCombineShifts.cpp b/lib/Transforms/InstCombine/InstCombineShifts.cpp index c562d45a9e2b..c821292400cd 100644 --- a/lib/Transforms/InstCombine/InstCombineShifts.cpp +++ b/lib/Transforms/InstCombine/InstCombineShifts.cpp @@ -1,9 +1,8 @@ //===- InstCombineShifts.cpp ----------------------------------------------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // @@ -21,6 +20,51 @@ using namespace PatternMatch; #define DEBUG_TYPE "instcombine" +// Given pattern: +// (x shiftopcode Q) shiftopcode K +// we should rewrite it as +// x shiftopcode (Q+K) iff (Q+K) u< bitwidth(x) +// This is valid for any shift, but they must be identical. +static Instruction * +reassociateShiftAmtsOfTwoSameDirectionShifts(BinaryOperator *Sh0, + const SimplifyQuery &SQ) { + // Look for: (x shiftopcode ShAmt0) shiftopcode ShAmt1 + Value *X, *ShAmt1, *ShAmt0; + Instruction *Sh1; + if (!match(Sh0, m_Shift(m_CombineAnd(m_Shift(m_Value(X), m_Value(ShAmt1)), + m_Instruction(Sh1)), + m_Value(ShAmt0)))) + return nullptr; + + // The shift opcodes must be identical. + Instruction::BinaryOps ShiftOpcode = Sh0->getOpcode(); + if (ShiftOpcode != Sh1->getOpcode()) + return nullptr; + // Can we fold (ShAmt0+ShAmt1) ? + Value *NewShAmt = SimplifyBinOp(Instruction::BinaryOps::Add, ShAmt0, ShAmt1, + SQ.getWithInstruction(Sh0)); + if (!NewShAmt) + return nullptr; // Did not simplify. + // Is the new shift amount smaller than the bit width? + // FIXME: could also rely on ConstantRange. + unsigned BitWidth = X->getType()->getScalarSizeInBits(); + if (!match(NewShAmt, m_SpecificInt_ICMP(ICmpInst::Predicate::ICMP_ULT, + APInt(BitWidth, BitWidth)))) + return nullptr; + // All good, we can do this fold. + BinaryOperator *NewShift = BinaryOperator::Create(ShiftOpcode, X, NewShAmt); + // If both of the original shifts had the same flag set, preserve the flag. + if (ShiftOpcode == Instruction::BinaryOps::Shl) { + NewShift->setHasNoUnsignedWrap(Sh0->hasNoUnsignedWrap() && + Sh1->hasNoUnsignedWrap()); + NewShift->setHasNoSignedWrap(Sh0->hasNoSignedWrap() && + Sh1->hasNoSignedWrap()); + } else { + NewShift->setIsExact(Sh0->isExact() && Sh1->isExact()); + } + return NewShift; +} + Instruction *InstCombiner::commonShiftTransforms(BinaryOperator &I) { Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); assert(Op0->getType() == Op1->getType()); @@ -39,6 +83,10 @@ Instruction *InstCombiner::commonShiftTransforms(BinaryOperator &I) { if (Instruction *Res = FoldShiftByConstant(Op0, CUI, I)) return Res; + if (Instruction *NewShift = + reassociateShiftAmtsOfTwoSameDirectionShifts(&I, SQ)) + return NewShift; + // (C1 shift (A add C2)) -> (C1 shift C2) shift A) // iff A and C2 are both positive. Value *A; @@ -313,35 +361,17 @@ static Value *getShiftedValue(Value *V, unsigned NumBits, bool isLeftShift, // If this is a bitwise operator or add with a constant RHS we might be able // to pull it through a shift. static bool canShiftBinOpWithConstantRHS(BinaryOperator &Shift, - BinaryOperator *BO, - const APInt &C) { - bool IsValid = true; // Valid only for And, Or Xor, - bool HighBitSet = false; // Transform ifhigh bit of constant set? - + BinaryOperator *BO) { switch (BO->getOpcode()) { - default: IsValid = false; break; // Do not perform transform! + default: + return false; // Do not perform transform! case Instruction::Add: - IsValid = Shift.getOpcode() == Instruction::Shl; - break; + return Shift.getOpcode() == Instruction::Shl; case Instruction::Or: case Instruction::Xor: - HighBitSet = false; - break; case Instruction::And: - HighBitSet = true; - break; + return true; } - - // If this is a signed shift right, and the high bit is modified - // by the logical operation, do not perform the transformation. - // The HighBitSet boolean indicates the value of the high bit of - // the constant which would cause it to be modified for this - // operation. - // - if (IsValid && Shift.getOpcode() == Instruction::AShr) - IsValid = C.isNegative() == HighBitSet; - - return IsValid; } Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, Constant *Op1, @@ -508,7 +538,7 @@ Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, Constant *Op1, // shift is the only use, we can pull it out of the shift. const APInt *Op0C; if (match(Op0BO->getOperand(1), m_APInt(Op0C))) { - if (canShiftBinOpWithConstantRHS(I, Op0BO, *Op0C)) { + if (canShiftBinOpWithConstantRHS(I, Op0BO)) { Constant *NewRHS = ConstantExpr::get(I.getOpcode(), cast<Constant>(Op0BO->getOperand(1)), Op1); @@ -552,7 +582,7 @@ Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, Constant *Op1, const APInt *C; if (!isa<Constant>(FalseVal) && TBO->getOperand(0) == FalseVal && match(TBO->getOperand(1), m_APInt(C)) && - canShiftBinOpWithConstantRHS(I, TBO, *C)) { + canShiftBinOpWithConstantRHS(I, TBO)) { Constant *NewRHS = ConstantExpr::get(I.getOpcode(), cast<Constant>(TBO->getOperand(1)), Op1); @@ -571,7 +601,7 @@ Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, Constant *Op1, const APInt *C; if (!isa<Constant>(TrueVal) && FBO->getOperand(0) == TrueVal && match(FBO->getOperand(1), m_APInt(C)) && - canShiftBinOpWithConstantRHS(I, FBO, *C)) { + canShiftBinOpWithConstantRHS(I, FBO)) { Constant *NewRHS = ConstantExpr::get(I.getOpcode(), cast<Constant>(FBO->getOperand(1)), Op1); @@ -601,6 +631,8 @@ Instruction *InstCombiner::visitShl(BinaryOperator &I) { Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); Type *Ty = I.getType(); + unsigned BitWidth = Ty->getScalarSizeInBits(); + const APInt *ShAmtAPInt; if (match(Op1, m_APInt(ShAmtAPInt))) { unsigned ShAmt = ShAmtAPInt->getZExtValue(); @@ -689,6 +721,12 @@ Instruction *InstCombiner::visitShl(BinaryOperator &I) { return BinaryOperator::CreateMul(X, ConstantExpr::getShl(C2, C1)); } + // (1 << (C - x)) -> ((1 << C) >> x) if C is bitwidth - 1 + if (match(Op0, m_One()) && + match(Op1, m_Sub(m_SpecificInt(BitWidth - 1), m_Value(X)))) + return BinaryOperator::CreateLShr( + ConstantInt::get(Ty, APInt::getSignMask(BitWidth)), X); + return nullptr; } |
