Diffstat (limited to 'llvm/lib/Transforms/AggressiveInstCombine')
 llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp        | 417
 llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombineInternal.h  | 125
 llvm/lib/Transforms/AggressiveInstCombine/TruncInstCombine.cpp             | 417
3 files changed, 959 insertions, 0 deletions
diff --git a/llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp b/llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp new file mode 100644 index 000000000000..a24de3ca213f --- /dev/null +++ b/llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp @@ -0,0 +1,417 @@ +//===- AggressiveInstCombine.cpp ------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the aggressive expression pattern combiner classes. +// Currently, it handles expression patterns for: +// * Truncate instruction +// +//===----------------------------------------------------------------------===// + +#include "llvm/Transforms/AggressiveInstCombine/AggressiveInstCombine.h" +#include "AggressiveInstCombineInternal.h" +#include "llvm-c/Initialization.h" +#include "llvm-c/Transforms/AggressiveInstCombine.h" +#include "llvm/Analysis/AliasAnalysis.h" +#include "llvm/Analysis/BasicAliasAnalysis.h" +#include "llvm/Analysis/GlobalsModRef.h" +#include "llvm/Analysis/TargetLibraryInfo.h" +#include "llvm/IR/DataLayout.h" +#include "llvm/IR/Dominators.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/LegacyPassManager.h" +#include "llvm/IR/PatternMatch.h" +#include "llvm/Pass.h" +#include "llvm/Transforms/Utils/Local.h" +using namespace llvm; +using namespace PatternMatch; + +#define DEBUG_TYPE "aggressive-instcombine" + +namespace { +/// Contains expression pattern combiner logic. +/// This class provides both the logic to combine expression patterns and +/// combine them. It differs from InstCombiner class in that each pattern +/// combiner runs only once as opposed to InstCombine's multi-iteration, +/// which allows pattern combiner to have higher complexity than the O(1) +/// required by the instruction combiner. +class AggressiveInstCombinerLegacyPass : public FunctionPass { +public: + static char ID; // Pass identification, replacement for typeid + + AggressiveInstCombinerLegacyPass() : FunctionPass(ID) { + initializeAggressiveInstCombinerLegacyPassPass( + *PassRegistry::getPassRegistry()); + } + + void getAnalysisUsage(AnalysisUsage &AU) const override; + + /// Run all expression pattern optimizations on the given /p F function. + /// + /// \param F function to optimize. + /// \returns true if the IR is changed. + bool runOnFunction(Function &F) override; +}; +} // namespace + +/// Match a pattern for a bitwise rotate operation that partially guards +/// against undefined behavior by branching around the rotation when the shift +/// amount is 0. +static bool foldGuardedRotateToFunnelShift(Instruction &I) { + if (I.getOpcode() != Instruction::PHI || I.getNumOperands() != 2) + return false; + + // As with the one-use checks below, this is not strictly necessary, but we + // are being cautious to avoid potential perf regressions on targets that + // do not actually have a rotate instruction (where the funnel shift would be + // expanded back into math/shift/logic ops). + if (!isPowerOf2_32(I.getType()->getScalarSizeInBits())) + return false; + + // Match V to funnel shift left/right and capture the source operand and + // shift amount in X and Y. 
+ auto matchRotate = [](Value *V, Value *&X, Value *&Y) { + Value *L0, *L1, *R0, *R1; + unsigned Width = V->getType()->getScalarSizeInBits(); + auto Sub = m_Sub(m_SpecificInt(Width), m_Value(R1)); + + // rotate_left(X, Y) == (X << Y) | (X >> (Width - Y)) + auto RotL = m_OneUse( + m_c_Or(m_Shl(m_Value(L0), m_Value(L1)), m_LShr(m_Value(R0), Sub))); + if (RotL.match(V) && L0 == R0 && L1 == R1) { + X = L0; + Y = L1; + return Intrinsic::fshl; + } + + // rotate_right(X, Y) == (X >> Y) | (X << (Width - Y)) + auto RotR = m_OneUse( + m_c_Or(m_LShr(m_Value(L0), m_Value(L1)), m_Shl(m_Value(R0), Sub))); + if (RotR.match(V) && L0 == R0 && L1 == R1) { + X = L0; + Y = L1; + return Intrinsic::fshr; + } + + return Intrinsic::not_intrinsic; + }; + + // One phi operand must be a rotate operation, and the other phi operand must + // be the source value of that rotate operation: + // phi [ rotate(RotSrc, RotAmt), RotBB ], [ RotSrc, GuardBB ] + PHINode &Phi = cast<PHINode>(I); + Value *P0 = Phi.getOperand(0), *P1 = Phi.getOperand(1); + Value *RotSrc, *RotAmt; + Intrinsic::ID IID = matchRotate(P0, RotSrc, RotAmt); + if (IID == Intrinsic::not_intrinsic || RotSrc != P1) { + IID = matchRotate(P1, RotSrc, RotAmt); + if (IID == Intrinsic::not_intrinsic || RotSrc != P0) + return false; + assert((IID == Intrinsic::fshl || IID == Intrinsic::fshr) && + "Pattern must match funnel shift left or right"); + } + + // The incoming block with our source operand must be the "guard" block. + // That must contain a cmp+branch to avoid the rotate when the shift amount + // is equal to 0. The other incoming block is the block with the rotate. + BasicBlock *GuardBB = Phi.getIncomingBlock(RotSrc == P1); + BasicBlock *RotBB = Phi.getIncomingBlock(RotSrc != P1); + Instruction *TermI = GuardBB->getTerminator(); + ICmpInst::Predicate Pred; + BasicBlock *PhiBB = Phi.getParent(); + if (!match(TermI, m_Br(m_ICmp(Pred, m_Specific(RotAmt), m_ZeroInt()), + m_SpecificBB(PhiBB), m_SpecificBB(RotBB)))) + return false; + + if (Pred != CmpInst::ICMP_EQ) + return false; + + // We matched a variation of this IR pattern: + // GuardBB: + // %cmp = icmp eq i32 %RotAmt, 0 + // br i1 %cmp, label %PhiBB, label %RotBB + // RotBB: + // %sub = sub i32 32, %RotAmt + // %shr = lshr i32 %X, %sub + // %shl = shl i32 %X, %RotAmt + // %rot = or i32 %shr, %shl + // br label %PhiBB + // PhiBB: + // %cond = phi i32 [ %rot, %RotBB ], [ %X, %GuardBB ] + // --> + // llvm.fshl.i32(i32 %X, i32 %RotAmt) + IRBuilder<> Builder(PhiBB, PhiBB->getFirstInsertionPt()); + Function *F = Intrinsic::getDeclaration(Phi.getModule(), IID, Phi.getType()); + Phi.replaceAllUsesWith(Builder.CreateCall(F, {RotSrc, RotSrc, RotAmt})); + return true; +} + +/// This is used by foldAnyOrAllBitsSet() to capture a source value (Root) and +/// the bit indexes (Mask) needed by a masked compare. If we're matching a chain +/// of 'and' ops, then we also need to capture the fact that we saw an +/// "and X, 1", so that's an extra return value for that case. +struct MaskOps { + Value *Root; + APInt Mask; + bool MatchAndChain; + bool FoundAnd1; + + MaskOps(unsigned BitWidth, bool MatchAnds) + : Root(nullptr), Mask(APInt::getNullValue(BitWidth)), + MatchAndChain(MatchAnds), FoundAnd1(false) {} +}; + +/// This is a recursive helper for foldAnyOrAllBitsSet() that walks through a +/// chain of 'and' or 'or' instructions looking for shift ops of a common source +/// value. 
Examples: +/// or (or (or X, (X >> 3)), (X >> 5)), (X >> 8) +/// returns { X, 0x129 } +/// and (and (X >> 1), 1), (X >> 4) +/// returns { X, 0x12 } +static bool matchAndOrChain(Value *V, MaskOps &MOps) { + Value *Op0, *Op1; + if (MOps.MatchAndChain) { + // Recurse through a chain of 'and' operands. This requires an extra check + // vs. the 'or' matcher: we must find an "and X, 1" instruction somewhere + // in the chain to know that all of the high bits are cleared. + if (match(V, m_And(m_Value(Op0), m_One()))) { + MOps.FoundAnd1 = true; + return matchAndOrChain(Op0, MOps); + } + if (match(V, m_And(m_Value(Op0), m_Value(Op1)))) + return matchAndOrChain(Op0, MOps) && matchAndOrChain(Op1, MOps); + } else { + // Recurse through a chain of 'or' operands. + if (match(V, m_Or(m_Value(Op0), m_Value(Op1)))) + return matchAndOrChain(Op0, MOps) && matchAndOrChain(Op1, MOps); + } + + // We need a shift-right or a bare value representing a compare of bit 0 of + // the original source operand. + Value *Candidate; + uint64_t BitIndex = 0; + if (!match(V, m_LShr(m_Value(Candidate), m_ConstantInt(BitIndex)))) + Candidate = V; + + // Initialize result source operand. + if (!MOps.Root) + MOps.Root = Candidate; + + // The shift constant is out-of-range? This code hasn't been simplified. + if (BitIndex >= MOps.Mask.getBitWidth()) + return false; + + // Fill in the mask bit derived from the shift constant. + MOps.Mask.setBit(BitIndex); + return MOps.Root == Candidate; +} + +/// Match patterns that correspond to "any-bits-set" and "all-bits-set". +/// These will include a chain of 'or' or 'and'-shifted bits from a +/// common source value: +/// and (or (lshr X, C), ...), 1 --> (X & CMask) != 0 +/// and (and (lshr X, C), ...), 1 --> (X & CMask) == CMask +/// Note: "any-bits-clear" and "all-bits-clear" are variations of these patterns +/// that differ only with a final 'not' of the result. We expect that final +/// 'not' to be folded with the compare that we create here (invert predicate). +static bool foldAnyOrAllBitsSet(Instruction &I) { + // The 'any-bits-set' ('or' chain) pattern is simpler to match because the + // final "and X, 1" instruction must be the final op in the sequence. + bool MatchAllBitsSet; + if (match(&I, m_c_And(m_OneUse(m_And(m_Value(), m_Value())), m_Value()))) + MatchAllBitsSet = true; + else if (match(&I, m_And(m_OneUse(m_Or(m_Value(), m_Value())), m_One()))) + MatchAllBitsSet = false; + else + return false; + + MaskOps MOps(I.getType()->getScalarSizeInBits(), MatchAllBitsSet); + if (MatchAllBitsSet) { + if (!matchAndOrChain(cast<BinaryOperator>(&I), MOps) || !MOps.FoundAnd1) + return false; + } else { + if (!matchAndOrChain(cast<BinaryOperator>(&I)->getOperand(0), MOps)) + return false; + } + + // The pattern was found. Create a masked compare that replaces all of the + // shift and logic ops. + IRBuilder<> Builder(&I); + Constant *Mask = ConstantInt::get(I.getType(), MOps.Mask); + Value *And = Builder.CreateAnd(MOps.Root, Mask); + Value *Cmp = MatchAllBitsSet ? Builder.CreateICmpEQ(And, Mask) + : Builder.CreateIsNotNull(And); + Value *Zext = Builder.CreateZExt(Cmp, I.getType()); + I.replaceAllUsesWith(Zext); + return true; +} + +// Try to recognize below function as popcount intrinsic. +// This is the "best" algorithm from +// http://graphics.stanford.edu/~seander/bithacks.html#CountBitsSetParallel +// Also used in TargetLowering::expandCTPOP(). 
+// +// int popcount(unsigned int i) { +// i = i - ((i >> 1) & 0x55555555); +// i = (i & 0x33333333) + ((i >> 2) & 0x33333333); +// i = ((i + (i >> 4)) & 0x0F0F0F0F); +// return (i * 0x01010101) >> 24; +// } +static bool tryToRecognizePopCount(Instruction &I) { + if (I.getOpcode() != Instruction::LShr) + return false; + + Type *Ty = I.getType(); + if (!Ty->isIntOrIntVectorTy()) + return false; + + unsigned Len = Ty->getScalarSizeInBits(); + // FIXME: fix Len == 8 and other irregular type lengths. + if (!(Len <= 128 && Len > 8 && Len % 8 == 0)) + return false; + + APInt Mask55 = APInt::getSplat(Len, APInt(8, 0x55)); + APInt Mask33 = APInt::getSplat(Len, APInt(8, 0x33)); + APInt Mask0F = APInt::getSplat(Len, APInt(8, 0x0F)); + APInt Mask01 = APInt::getSplat(Len, APInt(8, 0x01)); + APInt MaskShift = APInt(Len, Len - 8); + + Value *Op0 = I.getOperand(0); + Value *Op1 = I.getOperand(1); + Value *MulOp0; + // Matching "(i * 0x01010101...) >> 24". + if ((match(Op0, m_Mul(m_Value(MulOp0), m_SpecificInt(Mask01)))) && + match(Op1, m_SpecificInt(MaskShift))) { + Value *ShiftOp0; + // Matching "((i + (i >> 4)) & 0x0F0F0F0F...)". + if (match(MulOp0, m_And(m_c_Add(m_LShr(m_Value(ShiftOp0), m_SpecificInt(4)), + m_Deferred(ShiftOp0)), + m_SpecificInt(Mask0F)))) { + Value *AndOp0; + // Matching "(i & 0x33333333...) + ((i >> 2) & 0x33333333...)". + if (match(ShiftOp0, + m_c_Add(m_And(m_Value(AndOp0), m_SpecificInt(Mask33)), + m_And(m_LShr(m_Deferred(AndOp0), m_SpecificInt(2)), + m_SpecificInt(Mask33))))) { + Value *Root, *SubOp1; + // Matching "i - ((i >> 1) & 0x55555555...)". + if (match(AndOp0, m_Sub(m_Value(Root), m_Value(SubOp1))) && + match(SubOp1, m_And(m_LShr(m_Specific(Root), m_SpecificInt(1)), + m_SpecificInt(Mask55)))) { + LLVM_DEBUG(dbgs() << "Recognized popcount intrinsic\n"); + IRBuilder<> Builder(&I); + Function *Func = Intrinsic::getDeclaration( + I.getModule(), Intrinsic::ctpop, I.getType()); + I.replaceAllUsesWith(Builder.CreateCall(Func, {Root})); + return true; + } + } + } + } + + return false; +} + +/// This is the entry point for folds that could be implemented in regular +/// InstCombine, but they are separated because they are not expected to +/// occur frequently and/or have more than a constant-length pattern match. +static bool foldUnusualPatterns(Function &F, DominatorTree &DT) { + bool MadeChange = false; + for (BasicBlock &BB : F) { + // Ignore unreachable basic blocks. + if (!DT.isReachableFromEntry(&BB)) + continue; + // Do not delete instructions under here and invalidate the iterator. + // Walk the block backwards for efficiency. We're matching a chain of + // use->defs, so we're more likely to succeed by starting from the bottom. + // Also, we want to avoid matching partial patterns. + // TODO: It would be more efficient if we removed dead instructions + // iteratively in this loop rather than waiting until the end. + for (Instruction &I : make_range(BB.rbegin(), BB.rend())) { + MadeChange |= foldAnyOrAllBitsSet(I); + MadeChange |= foldGuardedRotateToFunnelShift(I); + MadeChange |= tryToRecognizePopCount(I); + } + } + + // We're done with transforms, so remove dead instructions. + if (MadeChange) + for (BasicBlock &BB : F) + SimplifyInstructionsInBlock(&BB); + + return MadeChange; +} + +/// This is the entry point for all transforms. Pass manager differences are +/// handled in the callers of this function. 
+static bool runImpl(Function &F, TargetLibraryInfo &TLI, DominatorTree &DT) { + bool MadeChange = false; + const DataLayout &DL = F.getParent()->getDataLayout(); + TruncInstCombine TIC(TLI, DL, DT); + MadeChange |= TIC.run(F); + MadeChange |= foldUnusualPatterns(F, DT); + return MadeChange; +} + +void AggressiveInstCombinerLegacyPass::getAnalysisUsage( + AnalysisUsage &AU) const { + AU.setPreservesCFG(); + AU.addRequired<DominatorTreeWrapperPass>(); + AU.addRequired<TargetLibraryInfoWrapperPass>(); + AU.addPreserved<AAResultsWrapperPass>(); + AU.addPreserved<BasicAAWrapperPass>(); + AU.addPreserved<DominatorTreeWrapperPass>(); + AU.addPreserved<GlobalsAAWrapperPass>(); +} + +bool AggressiveInstCombinerLegacyPass::runOnFunction(Function &F) { + auto &TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F); + auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree(); + return runImpl(F, TLI, DT); +} + +PreservedAnalyses AggressiveInstCombinePass::run(Function &F, + FunctionAnalysisManager &AM) { + auto &TLI = AM.getResult<TargetLibraryAnalysis>(F); + auto &DT = AM.getResult<DominatorTreeAnalysis>(F); + if (!runImpl(F, TLI, DT)) { + // No changes, all analyses are preserved. + return PreservedAnalyses::all(); + } + // Mark all the analyses that instcombine updates as preserved. + PreservedAnalyses PA; + PA.preserveSet<CFGAnalyses>(); + PA.preserve<AAManager>(); + PA.preserve<GlobalsAA>(); + return PA; +} + +char AggressiveInstCombinerLegacyPass::ID = 0; +INITIALIZE_PASS_BEGIN(AggressiveInstCombinerLegacyPass, + "aggressive-instcombine", + "Combine pattern based expressions", false, false) +INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) +INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) +INITIALIZE_PASS_END(AggressiveInstCombinerLegacyPass, "aggressive-instcombine", + "Combine pattern based expressions", false, false) + +// Initialization Routines +void llvm::initializeAggressiveInstCombine(PassRegistry &Registry) { + initializeAggressiveInstCombinerLegacyPassPass(Registry); +} + +void LLVMInitializeAggressiveInstCombiner(LLVMPassRegistryRef R) { + initializeAggressiveInstCombinerLegacyPassPass(*unwrap(R)); +} + +FunctionPass *llvm::createAggressiveInstCombinerPass() { + return new AggressiveInstCombinerLegacyPass(); +} + +void LLVMAddAggressiveInstCombinerPass(LLVMPassManagerRef PM) { + unwrap(PM)->add(createAggressiveInstCombinerPass()); +} diff --git a/llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombineInternal.h b/llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombineInternal.h new file mode 100644 index 000000000000..44e1c45664e7 --- /dev/null +++ b/llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombineInternal.h @@ -0,0 +1,125 @@ +//===- AggressiveInstCombineInternal.h --------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the instruction pattern combiner classes. 
+// Currently, it handles pattern expressions for: +// * Truncate instruction +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_LIB_TRANSFORMS_AGGRESSIVEINSTCOMBINE_COMBINEINTERNAL_H +#define LLVM_LIB_TRANSFORMS_AGGRESSIVEINSTCOMBINE_COMBINEINTERNAL_H + +#include "llvm/ADT/MapVector.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/Analysis/AliasAnalysis.h" +#include "llvm/Analysis/BasicAliasAnalysis.h" +#include "llvm/Analysis/ConstantFolding.h" +#include "llvm/Analysis/GlobalsModRef.h" +#include "llvm/Analysis/TargetLibraryInfo.h" +#include "llvm/IR/DataLayout.h" +#include "llvm/Pass.h" +using namespace llvm; + +//===----------------------------------------------------------------------===// +// TruncInstCombine - looks for expression dags dominated by trunc instructions +// and for each eligible dag, it will create a reduced bit-width expression and +// replace the old expression with this new one and remove the old one. +// Eligible expression dag is such that: +// 1. Contains only supported instructions. +// 2. Supported leaves: ZExtInst, SExtInst, TruncInst and Constant value. +// 3. Can be evaluated into type with reduced legal bit-width (or Trunc type). +// 4. All instructions in the dag must not have users outside the dag. +// Only exception is for {ZExt, SExt}Inst with operand type equal to the +// new reduced type chosen in (3). +// +// The motivation for this optimization is that evaluating and expression using +// smaller bit-width is preferable, especially for vectorization where we can +// fit more values in one vectorized instruction. In addition, this optimization +// may decrease the number of cast instructions, but will not increase it. +//===----------------------------------------------------------------------===// + +namespace llvm { + class DataLayout; + class DominatorTree; + class TargetLibraryInfo; + +class TruncInstCombine { + TargetLibraryInfo &TLI; + const DataLayout &DL; + const DominatorTree &DT; + + /// List of all TruncInst instructions to be processed. + SmallVector<TruncInst *, 4> Worklist; + + /// Current processed TruncInst instruction. + TruncInst *CurrentTruncInst; + + /// Information per each instruction in the expression dag. + struct Info { + /// Number of LSBs that are needed to generate a valid expression. + unsigned ValidBitWidth = 0; + /// Minimum number of LSBs needed to generate the ValidBitWidth. + unsigned MinBitWidth = 0; + /// The reduced value generated to replace the old instruction. + Value *NewValue = nullptr; + }; + /// An ordered map representing expression dag post-dominated by current + /// processed TruncInst. It maps each instruction in the dag to its Info + /// structure. The map is ordered such that each instruction appears before + /// all other instructions in the dag that uses it. + MapVector<Instruction *, Info> InstInfoMap; + +public: + TruncInstCombine(TargetLibraryInfo &TLI, const DataLayout &DL, + const DominatorTree &DT) + : TLI(TLI), DL(DL), DT(DT), CurrentTruncInst(nullptr) {} + + /// Perform TruncInst pattern optimization on given function. + bool run(Function &F); + +private: + /// Build expression dag dominated by the /p CurrentTruncInst and append it to + /// the InstInfoMap container. + /// + /// \return true only if succeed to generate an eligible sub expression dag. + bool buildTruncExpressionDag(); + + /// Calculate the minimal allowed bit-width of the chain ending with the + /// currently visited truncate's operand. 
+ /// + /// \return minimum number of bits to which the chain ending with the + /// truncate's operand can be shrunk to. + unsigned getMinBitWidth(); + + /// Build an expression dag dominated by the current processed TruncInst and + /// Check if it is eligible to be reduced to a smaller type. + /// + /// \return the scalar version of the new type to be used for the reduced + /// expression dag, or nullptr if the expression dag is not eligible + /// to be reduced. + Type *getBestTruncatedType(); + + /// Given a \p V value and a \p SclTy scalar type return the generated reduced + /// value of \p V based on the type \p SclTy. + /// + /// \param V value to be reduced. + /// \param SclTy scalar version of new type to reduce to. + /// \return the new reduced value. + Value *getReducedOperand(Value *V, Type *SclTy); + + /// Create a new expression dag using the reduced /p SclTy type and replace + /// the old expression dag with it. Also erase all instructions in the old + /// dag, except those that are still needed outside the dag. + /// + /// \param SclTy scalar version of new type to reduce expression dag into. + void ReduceExpressionDag(Type *SclTy); +}; +} // end namespace llvm. + +#endif diff --git a/llvm/lib/Transforms/AggressiveInstCombine/TruncInstCombine.cpp b/llvm/lib/Transforms/AggressiveInstCombine/TruncInstCombine.cpp new file mode 100644 index 000000000000..7c5767912fd3 --- /dev/null +++ b/llvm/lib/Transforms/AggressiveInstCombine/TruncInstCombine.cpp @@ -0,0 +1,417 @@ +//===- TruncInstCombine.cpp -----------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// TruncInstCombine - looks for expression dags post-dominated by TruncInst and +// for each eligible dag, it will create a reduced bit-width expression, replace +// the old expression with this new one and remove the old expression. +// Eligible expression dag is such that: +// 1. Contains only supported instructions. +// 2. Supported leaves: ZExtInst, SExtInst, TruncInst and Constant value. +// 3. Can be evaluated into type with reduced legal bit-width. +// 4. All instructions in the dag must not have users outside the dag. +// The only exception is for {ZExt, SExt}Inst with operand type equal to +// the new reduced type evaluated in (3). +// +// The motivation for this optimization is that evaluating and expression using +// smaller bit-width is preferable, especially for vectorization where we can +// fit more values in one vectorized instruction. In addition, this optimization +// may decrease the number of cast instructions, but will not increase it. +// +//===----------------------------------------------------------------------===// + +#include "AggressiveInstCombineInternal.h" +#include "llvm/ADT/MapVector.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/Analysis/ConstantFolding.h" +#include "llvm/Analysis/TargetLibraryInfo.h" +#include "llvm/IR/DataLayout.h" +#include "llvm/IR/Dominators.h" +#include "llvm/IR/IRBuilder.h" +using namespace llvm; + +#define DEBUG_TYPE "aggressive-instcombine" + +/// Given an instruction and a container, it fills all the relevant operands of +/// that instruction, with respect to the Trunc expression dag optimizaton. 
+static void getRelevantOperands(Instruction *I, SmallVectorImpl<Value *> &Ops) { + unsigned Opc = I->getOpcode(); + switch (Opc) { + case Instruction::Trunc: + case Instruction::ZExt: + case Instruction::SExt: + // These CastInst are considered leaves of the evaluated expression, thus, + // their operands are not relevent. + break; + case Instruction::Add: + case Instruction::Sub: + case Instruction::Mul: + case Instruction::And: + case Instruction::Or: + case Instruction::Xor: + Ops.push_back(I->getOperand(0)); + Ops.push_back(I->getOperand(1)); + break; + default: + llvm_unreachable("Unreachable!"); + } +} + +bool TruncInstCombine::buildTruncExpressionDag() { + SmallVector<Value *, 8> Worklist; + SmallVector<Instruction *, 8> Stack; + // Clear old expression dag. + InstInfoMap.clear(); + + Worklist.push_back(CurrentTruncInst->getOperand(0)); + + while (!Worklist.empty()) { + Value *Curr = Worklist.back(); + + if (isa<Constant>(Curr)) { + Worklist.pop_back(); + continue; + } + + auto *I = dyn_cast<Instruction>(Curr); + if (!I) + return false; + + if (!Stack.empty() && Stack.back() == I) { + // Already handled all instruction operands, can remove it from both the + // Worklist and the Stack, and add it to the instruction info map. + Worklist.pop_back(); + Stack.pop_back(); + // Insert I to the Info map. + InstInfoMap.insert(std::make_pair(I, Info())); + continue; + } + + if (InstInfoMap.count(I)) { + Worklist.pop_back(); + continue; + } + + // Add the instruction to the stack before start handling its operands. + Stack.push_back(I); + + unsigned Opc = I->getOpcode(); + switch (Opc) { + case Instruction::Trunc: + case Instruction::ZExt: + case Instruction::SExt: + // trunc(trunc(x)) -> trunc(x) + // trunc(ext(x)) -> ext(x) if the source type is smaller than the new dest + // trunc(ext(x)) -> trunc(x) if the source type is larger than the new + // dest + break; + case Instruction::Add: + case Instruction::Sub: + case Instruction::Mul: + case Instruction::And: + case Instruction::Or: + case Instruction::Xor: { + SmallVector<Value *, 2> Operands; + getRelevantOperands(I, Operands); + for (Value *Operand : Operands) + Worklist.push_back(Operand); + break; + } + default: + // TODO: Can handle more cases here: + // 1. select, shufflevector, extractelement, insertelement + // 2. udiv, urem + // 3. shl, lshr, ashr + // 4. phi node(and loop handling) + // ... + return false; + } + } + return true; +} + +unsigned TruncInstCombine::getMinBitWidth() { + SmallVector<Value *, 8> Worklist; + SmallVector<Instruction *, 8> Stack; + + Value *Src = CurrentTruncInst->getOperand(0); + Type *DstTy = CurrentTruncInst->getType(); + unsigned TruncBitWidth = DstTy->getScalarSizeInBits(); + unsigned OrigBitWidth = + CurrentTruncInst->getOperand(0)->getType()->getScalarSizeInBits(); + + if (isa<Constant>(Src)) + return TruncBitWidth; + + Worklist.push_back(Src); + InstInfoMap[cast<Instruction>(Src)].ValidBitWidth = TruncBitWidth; + + while (!Worklist.empty()) { + Value *Curr = Worklist.back(); + + if (isa<Constant>(Curr)) { + Worklist.pop_back(); + continue; + } + + // Otherwise, it must be an instruction. + auto *I = cast<Instruction>(Curr); + + auto &Info = InstInfoMap[I]; + + SmallVector<Value *, 2> Operands; + getRelevantOperands(I, Operands); + + if (!Stack.empty() && Stack.back() == I) { + // Already handled all instruction operands, can remove it from both, the + // Worklist and the Stack, and update MinBitWidth. 
+ Worklist.pop_back(); + Stack.pop_back(); + for (auto *Operand : Operands) + if (auto *IOp = dyn_cast<Instruction>(Operand)) + Info.MinBitWidth = + std::max(Info.MinBitWidth, InstInfoMap[IOp].MinBitWidth); + continue; + } + + // Add the instruction to the stack before start handling its operands. + Stack.push_back(I); + unsigned ValidBitWidth = Info.ValidBitWidth; + + // Update minimum bit-width before handling its operands. This is required + // when the instruction is part of a loop. + Info.MinBitWidth = std::max(Info.MinBitWidth, Info.ValidBitWidth); + + for (auto *Operand : Operands) + if (auto *IOp = dyn_cast<Instruction>(Operand)) { + // If we already calculated the minimum bit-width for this valid + // bit-width, or for a smaller valid bit-width, then just keep the + // answer we already calculated. + unsigned IOpBitwidth = InstInfoMap.lookup(IOp).ValidBitWidth; + if (IOpBitwidth >= ValidBitWidth) + continue; + InstInfoMap[IOp].ValidBitWidth = std::max(ValidBitWidth, IOpBitwidth); + Worklist.push_back(IOp); + } + } + unsigned MinBitWidth = InstInfoMap.lookup(cast<Instruction>(Src)).MinBitWidth; + assert(MinBitWidth >= TruncBitWidth); + + if (MinBitWidth > TruncBitWidth) { + // In this case reducing expression with vector type might generate a new + // vector type, which is not preferable as it might result in generating + // sub-optimal code. + if (DstTy->isVectorTy()) + return OrigBitWidth; + // Use the smallest integer type in the range [MinBitWidth, OrigBitWidth). + Type *Ty = DL.getSmallestLegalIntType(DstTy->getContext(), MinBitWidth); + // Update minimum bit-width with the new destination type bit-width if + // succeeded to find such, otherwise, with original bit-width. + MinBitWidth = Ty ? Ty->getScalarSizeInBits() : OrigBitWidth; + } else { // MinBitWidth == TruncBitWidth + // In this case the expression can be evaluated with the trunc instruction + // destination type, and trunc instruction can be omitted. However, we + // should not perform the evaluation if the original type is a legal scalar + // type and the target type is illegal. + bool FromLegal = MinBitWidth == 1 || DL.isLegalInteger(OrigBitWidth); + bool ToLegal = MinBitWidth == 1 || DL.isLegalInteger(MinBitWidth); + if (!DstTy->isVectorTy() && FromLegal && !ToLegal) + return OrigBitWidth; + } + return MinBitWidth; +} + +Type *TruncInstCombine::getBestTruncatedType() { + if (!buildTruncExpressionDag()) + return nullptr; + + // We don't want to duplicate instructions, which isn't profitable. Thus, we + // can't shrink something that has multiple users, unless all users are + // post-dominated by the trunc instruction, i.e., were visited during the + // expression evaluation. + unsigned DesiredBitWidth = 0; + for (auto Itr : InstInfoMap) { + Instruction *I = Itr.first; + if (I->hasOneUse()) + continue; + bool IsExtInst = (isa<ZExtInst>(I) || isa<SExtInst>(I)); + for (auto *U : I->users()) + if (auto *UI = dyn_cast<Instruction>(U)) + if (UI != CurrentTruncInst && !InstInfoMap.count(UI)) { + if (!IsExtInst) + return nullptr; + // If this is an extension from the dest type, we can eliminate it, + // even if it has multiple users. Thus, update the DesiredBitWidth and + // validate all extension instructions agrees on same DesiredBitWidth. 
+ unsigned ExtInstBitWidth = + I->getOperand(0)->getType()->getScalarSizeInBits(); + if (DesiredBitWidth && DesiredBitWidth != ExtInstBitWidth) + return nullptr; + DesiredBitWidth = ExtInstBitWidth; + } + } + + unsigned OrigBitWidth = + CurrentTruncInst->getOperand(0)->getType()->getScalarSizeInBits(); + + // Calculate minimum allowed bit-width allowed for shrinking the currently + // visited truncate's operand. + unsigned MinBitWidth = getMinBitWidth(); + + // Check that we can shrink to smaller bit-width than original one and that + // it is similar to the DesiredBitWidth is such exists. + if (MinBitWidth >= OrigBitWidth || + (DesiredBitWidth && DesiredBitWidth != MinBitWidth)) + return nullptr; + + return IntegerType::get(CurrentTruncInst->getContext(), MinBitWidth); +} + +/// Given a reduced scalar type \p Ty and a \p V value, return a reduced type +/// for \p V, according to its type, if it vector type, return the vector +/// version of \p Ty, otherwise return \p Ty. +static Type *getReducedType(Value *V, Type *Ty) { + assert(Ty && !Ty->isVectorTy() && "Expect Scalar Type"); + if (auto *VTy = dyn_cast<VectorType>(V->getType())) + return VectorType::get(Ty, VTy->getNumElements()); + return Ty; +} + +Value *TruncInstCombine::getReducedOperand(Value *V, Type *SclTy) { + Type *Ty = getReducedType(V, SclTy); + if (auto *C = dyn_cast<Constant>(V)) { + C = ConstantExpr::getIntegerCast(C, Ty, false); + // If we got a constantexpr back, try to simplify it with DL info. + if (Constant *FoldedC = ConstantFoldConstant(C, DL, &TLI)) + C = FoldedC; + return C; + } + + auto *I = cast<Instruction>(V); + Info Entry = InstInfoMap.lookup(I); + assert(Entry.NewValue); + return Entry.NewValue; +} + +void TruncInstCombine::ReduceExpressionDag(Type *SclTy) { + for (auto &Itr : InstInfoMap) { // Forward + Instruction *I = Itr.first; + TruncInstCombine::Info &NodeInfo = Itr.second; + + assert(!NodeInfo.NewValue && "Instruction has been evaluated"); + + IRBuilder<> Builder(I); + Value *Res = nullptr; + unsigned Opc = I->getOpcode(); + switch (Opc) { + case Instruction::Trunc: + case Instruction::ZExt: + case Instruction::SExt: { + Type *Ty = getReducedType(I, SclTy); + // If the source type of the cast is the type we're trying for then we can + // just return the source. There's no need to insert it because it is not + // new. + if (I->getOperand(0)->getType() == Ty) { + assert(!isa<TruncInst>(I) && "Cannot reach here with TruncInst"); + NodeInfo.NewValue = I->getOperand(0); + continue; + } + // Otherwise, must be the same type of cast, so just reinsert a new one. + // This also handles the case of zext(trunc(x)) -> zext(x). + Res = Builder.CreateIntCast(I->getOperand(0), Ty, + Opc == Instruction::SExt); + + // Update Worklist entries with new value if needed. + // There are three possible changes to the Worklist: + // 1. Update Old-TruncInst -> New-TruncInst. + // 2. Remove Old-TruncInst (if New node is not TruncInst). + // 3. Add New-TruncInst (if Old node was not TruncInst). 
+ auto Entry = find(Worklist, I); + if (Entry != Worklist.end()) { + if (auto *NewCI = dyn_cast<TruncInst>(Res)) + *Entry = NewCI; + else + Worklist.erase(Entry); + } else if (auto *NewCI = dyn_cast<TruncInst>(Res)) + Worklist.push_back(NewCI); + break; + } + case Instruction::Add: + case Instruction::Sub: + case Instruction::Mul: + case Instruction::And: + case Instruction::Or: + case Instruction::Xor: { + Value *LHS = getReducedOperand(I->getOperand(0), SclTy); + Value *RHS = getReducedOperand(I->getOperand(1), SclTy); + Res = Builder.CreateBinOp((Instruction::BinaryOps)Opc, LHS, RHS); + break; + } + default: + llvm_unreachable("Unhandled instruction"); + } + + NodeInfo.NewValue = Res; + if (auto *ResI = dyn_cast<Instruction>(Res)) + ResI->takeName(I); + } + + Value *Res = getReducedOperand(CurrentTruncInst->getOperand(0), SclTy); + Type *DstTy = CurrentTruncInst->getType(); + if (Res->getType() != DstTy) { + IRBuilder<> Builder(CurrentTruncInst); + Res = Builder.CreateIntCast(Res, DstTy, false); + if (auto *ResI = dyn_cast<Instruction>(Res)) + ResI->takeName(CurrentTruncInst); + } + CurrentTruncInst->replaceAllUsesWith(Res); + + // Erase old expression dag, which was replaced by the reduced expression dag. + // We iterate backward, which means we visit the instruction before we visit + // any of its operands, this way, when we get to the operand, we already + // removed the instructions (from the expression dag) that uses it. + CurrentTruncInst->eraseFromParent(); + for (auto I = InstInfoMap.rbegin(), E = InstInfoMap.rend(); I != E; ++I) { + // We still need to check that the instruction has no users before we erase + // it, because {SExt, ZExt}Inst Instruction might have other users that was + // not reduced, in such case, we need to keep that instruction. + if (I->first->use_empty()) + I->first->eraseFromParent(); + } +} + +bool TruncInstCombine::run(Function &F) { + bool MadeIRChange = false; + + // Collect all TruncInst in the function into the Worklist for evaluating. + for (auto &BB : F) { + // Ignore unreachable basic block. + if (!DT.isReachableFromEntry(&BB)) + continue; + for (auto &I : BB) + if (auto *CI = dyn_cast<TruncInst>(&I)) + Worklist.push_back(CI); + } + + // Process all TruncInst in the Worklist, for each instruction: + // 1. Check if it dominates an eligible expression dag to be reduced. + // 2. Create a reduced expression dag and replace the old one with it. + while (!Worklist.empty()) { + CurrentTruncInst = Worklist.pop_back_val(); + + if (Type *NewDstSclTy = getBestTruncatedType()) { + LLVM_DEBUG( + dbgs() << "ICE: TruncInstCombine reducing type of expression dag " + "dominated by: " + << CurrentTruncInst << '\n'); + ReduceExpressionDag(NewDstSclTy); + MadeIRChange = true; + } + } + + return MadeIRChange; +} |
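A few minimal IR sketches make the folds in this patch easier to read; they are illustrative inputs (function and value names invented here), not files from the commit. The legacy pass is registered above under the argument "aggressive-instcombine". First, the guarded-rotate idiom that foldGuardedRotateToFunnelShift targets: the phi's uses are replaced with a three-operand llvm.fshl call, the dead rotate math is cleaned up by SimplifyInstructionsInBlock at the end of foldUnusualPatterns, and the branch is left for later CFG cleanup.

define i32 @rotl_guarded(i32 %x, i32 %amt) {
entry:
  %cmp = icmp eq i32 %amt, 0
  br i1 %cmp, label %end, label %rotbb

rotbb:
  %sub = sub i32 32, %amt
  %shr = lshr i32 %x, %sub
  %shl = shl i32 %x, %amt
  %rot = or i32 %shr, %shl
  br label %end

end:
  %cond = phi i32 [ %rot, %rotbb ], [ %x, %entry ]
  ret i32 %cond
}

; Expected replacement for %cond:
;   %fsh = call i32 @llvm.fshl.i32(i32 %x, i32 %x, i32 %amt)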
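Next, a sketch of the "any-bits-set" form handled by foldAnyOrAllBitsSet: an 'or' chain of shifted copies of one source value, masked down to bit 0. Here the derived mask 296 is 0x128, i.e. bits 3, 5 and 8.

define i32 @any_of_bits_3_5_8(i32 %x) {
  %s3 = lshr i32 %x, 3
  %s5 = lshr i32 %x, 5
  %s8 = lshr i32 %x, 8
  %o1 = or i32 %s3, %s5
  %o2 = or i32 %o1, %s8
  %r  = and i32 %o2, 1
  ret i32 %r
}

; Expected replacement for %r:
;   %m = and i32 %x, 296
;   %c = icmp ne i32 %m, 0
;   %z = zext i1 %c to i32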
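A 32-bit IR version of the bit-twiddling popcount quoted as C in tryToRecognizePopCount; the match is anchored at the final lshr, and the constants are the 0x55/0x33/0x0F/0x01 byte splats written in decimal.

define i32 @popcount32(i32 %x) {
  %s1 = lshr i32 %x, 1
  %a1 = and i32 %s1, 1431655765      ; 0x55555555
  %t1 = sub i32 %x, %a1
  %a2 = and i32 %t1, 858993459       ; 0x33333333
  %s2 = lshr i32 %t1, 2
  %a3 = and i32 %s2, 858993459       ; 0x33333333
  %t2 = add i32 %a2, %a3
  %s3 = lshr i32 %t2, 4
  %t3 = add i32 %t2, %s3
  %a4 = and i32 %t3, 252645135       ; 0x0F0F0F0F
  %m  = mul i32 %a4, 16843009        ; 0x01010101
  %r  = lshr i32 %m, 24
  ret i32 %r
}

; Expected replacement for %r:
;   %ctpop = call i32 @llvm.ctpop.i32(i32 %x)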
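Finally, a sketch of the kind of dag TruncInstCombine reduces: a chain of supported ops whose only consumer is a trunc. Assuming a datalayout where i16 is a legal integer type (e.g. x86-64), the whole dag is re-created in i16 and the trunc disappears.

define i16 @trunc_dag(i32 %a, i32 %b) {
  %za  = zext i32 %a to i64
  %zb  = zext i32 %b to i64
  %add = add i64 %za, %zb
  %mul = mul i64 %add, 5
  %t   = trunc i64 %mul to i16
  ret i16 %t
}

; Expected result, roughly:
;   %za  = trunc i32 %a to i16
;   %zb  = trunc i32 %b to i16
;   %add = add i16 %za, %zb
;   %mul = mul i16 %add, 5
;   ret i16 %mul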