summaryrefslogtreecommitdiff
path: root/lib/Transforms/InstCombine/InstCombineShifts.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'lib/Transforms/InstCombine/InstCombineShifts.cpp')
-rw-r--r--lib/Transforms/InstCombine/InstCombineShifts.cpp98
1 files changed, 68 insertions, 30 deletions
diff --git a/lib/Transforms/InstCombine/InstCombineShifts.cpp b/lib/Transforms/InstCombine/InstCombineShifts.cpp
index c562d45a9e2b..c821292400cd 100644
--- a/lib/Transforms/InstCombine/InstCombineShifts.cpp
+++ b/lib/Transforms/InstCombine/InstCombineShifts.cpp
@@ -1,9 +1,8 @@
//===- InstCombineShifts.cpp ----------------------------------------------===//
//
-// The LLVM Compiler Infrastructure
-//
-// This file is distributed under the University of Illinois Open Source
-// License. See LICENSE.TXT for details.
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
@@ -21,6 +20,51 @@ using namespace PatternMatch;
#define DEBUG_TYPE "instcombine"
+// Given pattern:
+// (x shiftopcode Q) shiftopcode K
+// we should rewrite it as
+// x shiftopcode (Q+K) iff (Q+K) u< bitwidth(x)
+// This is valid for any shift, but they must be identical.
+static Instruction *
+reassociateShiftAmtsOfTwoSameDirectionShifts(BinaryOperator *Sh0,
+ const SimplifyQuery &SQ) {
+ // Look for: (x shiftopcode ShAmt0) shiftopcode ShAmt1
+ Value *X, *ShAmt1, *ShAmt0;
+ Instruction *Sh1;
+ if (!match(Sh0, m_Shift(m_CombineAnd(m_Shift(m_Value(X), m_Value(ShAmt1)),
+ m_Instruction(Sh1)),
+ m_Value(ShAmt0))))
+ return nullptr;
+
+ // The shift opcodes must be identical.
+ Instruction::BinaryOps ShiftOpcode = Sh0->getOpcode();
+ if (ShiftOpcode != Sh1->getOpcode())
+ return nullptr;
+ // Can we fold (ShAmt0+ShAmt1) ?
+ Value *NewShAmt = SimplifyBinOp(Instruction::BinaryOps::Add, ShAmt0, ShAmt1,
+ SQ.getWithInstruction(Sh0));
+ if (!NewShAmt)
+ return nullptr; // Did not simplify.
+ // Is the new shift amount smaller than the bit width?
+ // FIXME: could also rely on ConstantRange.
+ unsigned BitWidth = X->getType()->getScalarSizeInBits();
+ if (!match(NewShAmt, m_SpecificInt_ICMP(ICmpInst::Predicate::ICMP_ULT,
+ APInt(BitWidth, BitWidth))))
+ return nullptr;
+ // All good, we can do this fold.
+ BinaryOperator *NewShift = BinaryOperator::Create(ShiftOpcode, X, NewShAmt);
+ // If both of the original shifts had the same flag set, preserve the flag.
+ if (ShiftOpcode == Instruction::BinaryOps::Shl) {
+ NewShift->setHasNoUnsignedWrap(Sh0->hasNoUnsignedWrap() &&
+ Sh1->hasNoUnsignedWrap());
+ NewShift->setHasNoSignedWrap(Sh0->hasNoSignedWrap() &&
+ Sh1->hasNoSignedWrap());
+ } else {
+ NewShift->setIsExact(Sh0->isExact() && Sh1->isExact());
+ }
+ return NewShift;
+}
+
Instruction *InstCombiner::commonShiftTransforms(BinaryOperator &I) {
Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
assert(Op0->getType() == Op1->getType());
@@ -39,6 +83,10 @@ Instruction *InstCombiner::commonShiftTransforms(BinaryOperator &I) {
if (Instruction *Res = FoldShiftByConstant(Op0, CUI, I))
return Res;
+ if (Instruction *NewShift =
+ reassociateShiftAmtsOfTwoSameDirectionShifts(&I, SQ))
+ return NewShift;
+
// (C1 shift (A add C2)) -> (C1 shift C2) shift A)
// iff A and C2 are both positive.
Value *A;
@@ -313,35 +361,17 @@ static Value *getShiftedValue(Value *V, unsigned NumBits, bool isLeftShift,
// If this is a bitwise operator or add with a constant RHS we might be able
// to pull it through a shift.
static bool canShiftBinOpWithConstantRHS(BinaryOperator &Shift,
- BinaryOperator *BO,
- const APInt &C) {
- bool IsValid = true; // Valid only for And, Or Xor,
- bool HighBitSet = false; // Transform ifhigh bit of constant set?
-
+ BinaryOperator *BO) {
switch (BO->getOpcode()) {
- default: IsValid = false; break; // Do not perform transform!
+ default:
+ return false; // Do not perform transform!
case Instruction::Add:
- IsValid = Shift.getOpcode() == Instruction::Shl;
- break;
+ return Shift.getOpcode() == Instruction::Shl;
case Instruction::Or:
case Instruction::Xor:
- HighBitSet = false;
- break;
case Instruction::And:
- HighBitSet = true;
- break;
+ return true;
}
-
- // If this is a signed shift right, and the high bit is modified
- // by the logical operation, do not perform the transformation.
- // The HighBitSet boolean indicates the value of the high bit of
- // the constant which would cause it to be modified for this
- // operation.
- //
- if (IsValid && Shift.getOpcode() == Instruction::AShr)
- IsValid = C.isNegative() == HighBitSet;
-
- return IsValid;
}
Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, Constant *Op1,
@@ -508,7 +538,7 @@ Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, Constant *Op1,
// shift is the only use, we can pull it out of the shift.
const APInt *Op0C;
if (match(Op0BO->getOperand(1), m_APInt(Op0C))) {
- if (canShiftBinOpWithConstantRHS(I, Op0BO, *Op0C)) {
+ if (canShiftBinOpWithConstantRHS(I, Op0BO)) {
Constant *NewRHS = ConstantExpr::get(I.getOpcode(),
cast<Constant>(Op0BO->getOperand(1)), Op1);
@@ -552,7 +582,7 @@ Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, Constant *Op1,
const APInt *C;
if (!isa<Constant>(FalseVal) && TBO->getOperand(0) == FalseVal &&
match(TBO->getOperand(1), m_APInt(C)) &&
- canShiftBinOpWithConstantRHS(I, TBO, *C)) {
+ canShiftBinOpWithConstantRHS(I, TBO)) {
Constant *NewRHS = ConstantExpr::get(I.getOpcode(),
cast<Constant>(TBO->getOperand(1)), Op1);
@@ -571,7 +601,7 @@ Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, Constant *Op1,
const APInt *C;
if (!isa<Constant>(TrueVal) && FBO->getOperand(0) == TrueVal &&
match(FBO->getOperand(1), m_APInt(C)) &&
- canShiftBinOpWithConstantRHS(I, FBO, *C)) {
+ canShiftBinOpWithConstantRHS(I, FBO)) {
Constant *NewRHS = ConstantExpr::get(I.getOpcode(),
cast<Constant>(FBO->getOperand(1)), Op1);
@@ -601,6 +631,8 @@ Instruction *InstCombiner::visitShl(BinaryOperator &I) {
Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
Type *Ty = I.getType();
+ unsigned BitWidth = Ty->getScalarSizeInBits();
+
const APInt *ShAmtAPInt;
if (match(Op1, m_APInt(ShAmtAPInt))) {
unsigned ShAmt = ShAmtAPInt->getZExtValue();
@@ -689,6 +721,12 @@ Instruction *InstCombiner::visitShl(BinaryOperator &I) {
return BinaryOperator::CreateMul(X, ConstantExpr::getShl(C2, C1));
}
+ // (1 << (C - x)) -> ((1 << C) >> x) if C is bitwidth - 1
+ if (match(Op0, m_One()) &&
+ match(Op1, m_Sub(m_SpecificInt(BitWidth - 1), m_Value(X))))
+ return BinaryOperator::CreateLShr(
+ ConstantInt::get(Ty, APInt::getSignMask(BitWidth)), X);
+
return nullptr;
}