diff options
Diffstat (limited to 'lib/Transforms/Scalar/Reassociate.cpp')
-rw-r--r-- | lib/Transforms/Scalar/Reassociate.cpp | 220 |
1 files changed, 167 insertions, 53 deletions
diff --git a/lib/Transforms/Scalar/Reassociate.cpp b/lib/Transforms/Scalar/Reassociate.cpp index e235e5eb1a06..88dcaf0f8a36 100644 --- a/lib/Transforms/Scalar/Reassociate.cpp +++ b/lib/Transforms/Scalar/Reassociate.cpp @@ -21,28 +21,45 @@ //===----------------------------------------------------------------------===// #include "llvm/Transforms/Scalar/Reassociate.h" +#include "llvm/ADT/APFloat.h" +#include "llvm/ADT/APInt.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/PostOrderIterator.h" -#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SetVector.h" +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/ADT/SmallSet.h" +#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" #include "llvm/Analysis/GlobalsModRef.h" #include "llvm/Analysis/ValueTracking.h" +#include "llvm/IR/Argument.h" +#include "llvm/IR/BasicBlock.h" #include "llvm/IR/CFG.h" +#include "llvm/IR/Constant.h" #include "llvm/IR/Constants.h" -#include "llvm/IR/DerivedTypes.h" #include "llvm/IR/Function.h" #include "llvm/IR/IRBuilder.h" +#include "llvm/IR/InstrTypes.h" +#include "llvm/IR/Instruction.h" #include "llvm/IR/Instructions.h" -#include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/Operator.h" +#include "llvm/IR/PassManager.h" #include "llvm/IR/PatternMatch.h" +#include "llvm/IR/Type.h" +#include "llvm/IR/User.h" +#include "llvm/IR/Value.h" #include "llvm/IR/ValueHandle.h" #include "llvm/Pass.h" +#include "llvm/Support/Casting.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/ErrorHandling.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Utils/Local.h" #include <algorithm> +#include <cassert> +#include <utility> + using namespace llvm; using namespace reassociate; @@ -54,7 +71,6 @@ STATISTIC(NumFactor , "Number of multiplies factored"); #ifndef NDEBUG /// Print out the expression identified in the Ops list. -/// static void PrintOps(Instruction *I, const SmallVectorImpl<ValueEntry> &Ops) { Module *M = I->getModule(); dbgs() << Instruction::getOpcodeName(I->getOpcode()) << " " @@ -128,38 +144,37 @@ XorOpnd::XorOpnd(Value *V) { /// Return true if V is an instruction of the specified opcode and if it /// only has one use. static BinaryOperator *isReassociableOp(Value *V, unsigned Opcode) { - if (V->hasOneUse() && isa<Instruction>(V) && - cast<Instruction>(V)->getOpcode() == Opcode && - (!isa<FPMathOperator>(V) || - cast<Instruction>(V)->hasUnsafeAlgebra())) - return cast<BinaryOperator>(V); + auto *I = dyn_cast<Instruction>(V); + if (I && I->hasOneUse() && I->getOpcode() == Opcode) + if (!isa<FPMathOperator>(I) || I->isFast()) + return cast<BinaryOperator>(I); return nullptr; } static BinaryOperator *isReassociableOp(Value *V, unsigned Opcode1, unsigned Opcode2) { - if (V->hasOneUse() && isa<Instruction>(V) && - (cast<Instruction>(V)->getOpcode() == Opcode1 || - cast<Instruction>(V)->getOpcode() == Opcode2) && - (!isa<FPMathOperator>(V) || - cast<Instruction>(V)->hasUnsafeAlgebra())) - return cast<BinaryOperator>(V); + auto *I = dyn_cast<Instruction>(V); + if (I && I->hasOneUse() && + (I->getOpcode() == Opcode1 || I->getOpcode() == Opcode2)) + if (!isa<FPMathOperator>(I) || I->isFast()) + return cast<BinaryOperator>(I); return nullptr; } void ReassociatePass::BuildRankMap(Function &F, ReversePostOrderTraversal<Function*> &RPOT) { - unsigned i = 2; + unsigned Rank = 2; // Assign distinct ranks to function arguments. - for (Function::arg_iterator I = F.arg_begin(), E = F.arg_end(); I != E; ++I) { - ValueRankMap[&*I] = ++i; - DEBUG(dbgs() << "Calculated Rank[" << I->getName() << "] = " << i << "\n"); + for (auto &Arg : F.args()) { + ValueRankMap[&Arg] = ++Rank; + DEBUG(dbgs() << "Calculated Rank[" << Arg.getName() << "] = " << Rank + << "\n"); } // Traverse basic blocks in ReversePostOrder for (BasicBlock *BB : RPOT) { - unsigned BBRank = RankMap[BB] = ++i << 16; + unsigned BBRank = RankMap[BB] = ++Rank << 16; // Walk the basic block, adding precomputed ranks for any instructions that // we cannot move. This ensures that the ranks for these instructions are @@ -207,13 +222,9 @@ void ReassociatePass::canonicalizeOperands(Instruction *I) { Value *LHS = I->getOperand(0); Value *RHS = I->getOperand(1); - unsigned LHSRank = getRank(LHS); - unsigned RHSRank = getRank(RHS); - - if (isa<Constant>(RHS)) + if (LHS == RHS || isa<Constant>(RHS)) return; - - if (isa<Constant>(LHS) || RHSRank < LHSRank) + if (isa<Constant>(LHS) || getRank(RHS) < getRank(LHS)) cast<BinaryOperator>(I)->swapOperands(); } @@ -357,7 +368,7 @@ static void IncorporateWeight(APInt &LHS, const APInt &RHS, unsigned Opcode) { } } -typedef std::pair<Value*, APInt> RepeatedValue; +using RepeatedValue = std::pair<Value*, APInt>; /// Given an associative binary expression, return the leaf /// nodes in Ops along with their weights (how many times the leaf occurs). The @@ -432,7 +443,6 @@ typedef std::pair<Value*, APInt> RepeatedValue; /// that have all uses inside the expression (i.e. only used by non-leaf nodes /// of the expression) if it can turn them into binary operators of the right /// type and thus make the expression bigger. - static bool LinearizeExprTree(BinaryOperator *I, SmallVectorImpl<RepeatedValue> &Ops) { DEBUG(dbgs() << "LINEARIZE: " << *I << '\n'); @@ -470,12 +480,12 @@ static bool LinearizeExprTree(BinaryOperator *I, // Leaves - Keeps track of the set of putative leaves as well as the number of // paths to each leaf seen so far. - typedef DenseMap<Value*, APInt> LeafMap; + using LeafMap = DenseMap<Value *, APInt>; LeafMap Leaves; // Leaf -> Total weight so far. - SmallVector<Value*, 8> LeafOrder; // Ensure deterministic leaf output order. + SmallVector<Value *, 8> LeafOrder; // Ensure deterministic leaf output order. #ifndef NDEBUG - SmallPtrSet<Value*, 8> Visited; // For sanity checking the iteration scheme. + SmallPtrSet<Value *, 8> Visited; // For sanity checking the iteration scheme. #endif while (!Worklist.empty()) { std::pair<BinaryOperator*, APInt> P = Worklist.pop_back_val(); @@ -554,7 +564,7 @@ static bool LinearizeExprTree(BinaryOperator *I, assert((!isa<Instruction>(Op) || cast<Instruction>(Op)->getOpcode() != Opcode || (isa<FPMathOperator>(Op) && - !cast<Instruction>(Op)->hasUnsafeAlgebra())) && + !cast<Instruction>(Op)->isFast())) && "Should have been handled above!"); assert(Op->hasOneUse() && "Has uses outside the expression tree!"); @@ -773,7 +783,7 @@ void ReassociatePass::RewriteExprTree(BinaryOperator *I, break; ExpressionChanged->moveBefore(I); ExpressionChanged = cast<BinaryOperator>(*ExpressionChanged->user_begin()); - } while (1); + } while (true); // Throw away any left over nodes from the original expression. for (unsigned i = 0, e = NodesToRewrite.size(); i != e; ++i) @@ -789,13 +799,9 @@ void ReassociatePass::RewriteExprTree(BinaryOperator *I, /// additional opportunities have been exposed. static Value *NegateValue(Value *V, Instruction *BI, SetVector<AssertingVH<Instruction>> &ToRedo) { - if (Constant *C = dyn_cast<Constant>(V)) { - if (C->getType()->isFPOrFPVectorTy()) { - return ConstantExpr::getFNeg(C); - } - return ConstantExpr::getNeg(C); - } - + if (auto *C = dyn_cast<Constant>(V)) + return C->getType()->isFPOrFPVectorTy() ? ConstantExpr::getFNeg(C) : + ConstantExpr::getNeg(C); // We are trying to expose opportunity for reassociation. One of the things // that we want to do to achieve this is to push a negation as deep into an @@ -913,7 +919,6 @@ BreakUpSubtract(Instruction *Sub, SetVector<AssertingVH<Instruction>> &ToRedo) { // // Calculate the negative value of Operand 1 of the sub instruction, // and set it as the RHS of the add instruction we just made. - // Value *NegVal = NegateValue(Sub->getOperand(1), Sub, ToRedo); BinaryOperator *New = CreateAdd(Sub->getOperand(0), NegVal, "", Sub, Sub); Sub->setOperand(0, Constant::getNullValue(Sub->getType())); // Drop use of op. @@ -990,7 +995,7 @@ static Value *EmitAddTreeOfValues(Instruction *I, Value *V1 = Ops.back(); Ops.pop_back(); Value *V2 = EmitAddTreeOfValues(I, Ops); - return CreateAdd(V2, V1, "tmp", I, I); + return CreateAdd(V2, V1, "reass.add", I, I); } /// If V is an expression tree that is a multiplication sequence, @@ -1157,7 +1162,6 @@ static Value *createAndInstr(Instruction *InsertBefore, Value *Opnd, // If it was successful, true is returned, and the "R" and "C" is returned // via "Res" and "ConstOpnd", respectively; otherwise, false is returned, // and both "Res" and "ConstOpnd" remain unchanged. -// bool ReassociatePass::CombineXorOpnd(Instruction *I, XorOpnd *Opnd1, APInt &ConstOpnd, Value *&Res) { // Xor-Rule 1: (x | c1) ^ c2 = (x | c1) ^ (c1 ^ c1) ^ c2 @@ -1183,7 +1187,6 @@ bool ReassociatePass::CombineXorOpnd(Instruction *I, XorOpnd *Opnd1, RedoInsts.insert(T); return true; } - // Helper function of OptimizeXor(). It tries to simplify // "Opnd1 ^ Opnd2 ^ ConstOpnd" into "R ^ C", where C would be 0, and R is a @@ -1230,7 +1233,6 @@ bool ReassociatePass::CombineXorOpnd(Instruction *I, XorOpnd *Opnd1, Res = createAndInstr(I, X, C3); ConstOpnd ^= C1; - } else if (Opnd1->isOrExpr()) { // Xor-Rule 3: (x | c1) ^ (x | c2) = (x & c3) ^ c3 where c3 = c1 ^ c2 // @@ -1349,7 +1351,6 @@ Value *ReassociatePass::OptimizeXor(Instruction *I, // step 3.2: When previous and current operands share the same symbolic // value, try to simplify "PrevOpnd ^ CurrOpnd ^ ConstOpnd" - // if (CombineXorOpnd(I, CurrOpnd, PrevOpnd, ConstOpnd, CV)) { // Remove previous operand PrevOpnd->Invalidate(); @@ -1601,7 +1602,7 @@ Value *ReassociatePass::OptimizeAdd(Instruction *I, RedoInsts.insert(VI); // Create the multiply. - Instruction *V2 = CreateMul(V, MaxOccVal, "tmp", I, I); + Instruction *V2 = CreateMul(V, MaxOccVal, "reass.mul", I, I); // Rerun associate on the multiply in case the inner expression turned into // a multiply. We want to make sure that we keep things in canonical form. @@ -2012,8 +2013,8 @@ void ReassociatePass::OptimizeInst(Instruction *I) { if (I->isCommutative()) canonicalizeOperands(I); - // Don't optimize floating point instructions that don't have unsafe algebra. - if (I->getType()->isFPOrFPVectorTy() && !I->hasUnsafeAlgebra()) + // Don't optimize floating-point instructions unless they are 'fast'. + if (I->getType()->isFPOrFPVectorTy() && !I->isFast()) return; // Do not reassociate boolean (i1) expressions. We want to preserve the @@ -2140,7 +2141,8 @@ void ReassociatePass::ReassociateExpression(BinaryOperator *I) { DEBUG(dbgs() << "Reassoc to scalar: " << *V << '\n'); I->replaceAllUsesWith(V); if (Instruction *VI = dyn_cast<Instruction>(V)) - VI->setDebugLoc(I->getDebugLoc()); + if (I->getDebugLoc()) + VI->setDebugLoc(I->getDebugLoc()); RedoInsts.insert(I); ++NumAnnihil; return; @@ -2183,11 +2185,104 @@ void ReassociatePass::ReassociateExpression(BinaryOperator *I) { return; } + if (Ops.size() > 2 && Ops.size() <= GlobalReassociateLimit) { + // Find the pair with the highest count in the pairmap and move it to the + // back of the list so that it can later be CSE'd. + // example: + // a*b*c*d*e + // if c*e is the most "popular" pair, we can express this as + // (((c*e)*d)*b)*a + unsigned Max = 1; + unsigned BestRank = 0; + std::pair<unsigned, unsigned> BestPair; + unsigned Idx = I->getOpcode() - Instruction::BinaryOpsBegin; + for (unsigned i = 0; i < Ops.size() - 1; ++i) + for (unsigned j = i + 1; j < Ops.size(); ++j) { + unsigned Score = 0; + Value *Op0 = Ops[i].Op; + Value *Op1 = Ops[j].Op; + if (std::less<Value *>()(Op1, Op0)) + std::swap(Op0, Op1); + auto it = PairMap[Idx].find({Op0, Op1}); + if (it != PairMap[Idx].end()) + Score += it->second; + + unsigned MaxRank = std::max(Ops[i].Rank, Ops[j].Rank); + if (Score > Max || (Score == Max && MaxRank < BestRank)) { + BestPair = {i, j}; + Max = Score; + BestRank = MaxRank; + } + } + if (Max > 1) { + auto Op0 = Ops[BestPair.first]; + auto Op1 = Ops[BestPair.second]; + Ops.erase(&Ops[BestPair.second]); + Ops.erase(&Ops[BestPair.first]); + Ops.push_back(Op0); + Ops.push_back(Op1); + } + } // Now that we ordered and optimized the expressions, splat them back into // the expression tree, removing any unneeded nodes. RewriteExprTree(I, Ops); } +void +ReassociatePass::BuildPairMap(ReversePostOrderTraversal<Function *> &RPOT) { + // Make a "pairmap" of how often each operand pair occurs. + for (BasicBlock *BI : RPOT) { + for (Instruction &I : *BI) { + if (!I.isAssociative()) + continue; + + // Ignore nodes that aren't at the root of trees. + if (I.hasOneUse() && I.user_back()->getOpcode() == I.getOpcode()) + continue; + + // Collect all operands in a single reassociable expression. + // Since Reassociate has already been run once, we can assume things + // are already canonical according to Reassociation's regime. + SmallVector<Value *, 8> Worklist = { I.getOperand(0), I.getOperand(1) }; + SmallVector<Value *, 8> Ops; + while (!Worklist.empty() && Ops.size() <= GlobalReassociateLimit) { + Value *Op = Worklist.pop_back_val(); + Instruction *OpI = dyn_cast<Instruction>(Op); + if (!OpI || OpI->getOpcode() != I.getOpcode() || !OpI->hasOneUse()) { + Ops.push_back(Op); + continue; + } + // Be paranoid about self-referencing expressions in unreachable code. + if (OpI->getOperand(0) != OpI) + Worklist.push_back(OpI->getOperand(0)); + if (OpI->getOperand(1) != OpI) + Worklist.push_back(OpI->getOperand(1)); + } + // Skip extremely long expressions. + if (Ops.size() > GlobalReassociateLimit) + continue; + + // Add all pairwise combinations of operands to the pair map. + unsigned BinaryIdx = I.getOpcode() - Instruction::BinaryOpsBegin; + SmallSet<std::pair<Value *, Value*>, 32> Visited; + for (unsigned i = 0; i < Ops.size() - 1; ++i) { + for (unsigned j = i + 1; j < Ops.size(); ++j) { + // Canonicalize operand orderings. + Value *Op0 = Ops[i]; + Value *Op1 = Ops[j]; + if (std::less<Value *>()(Op1, Op0)) + std::swap(Op0, Op1); + if (!Visited.insert({Op0, Op1}).second) + continue; + auto res = PairMap[BinaryIdx].insert({{Op0, Op1}, 1}); + if (!res.second) + ++res.first->second; + } + } + } + } +} + PreservedAnalyses ReassociatePass::run(Function &F, FunctionAnalysisManager &) { // Get the functions basic blocks in Reverse Post Order. This order is used by // BuildRankMap to pre calculate ranks correctly. It also excludes dead basic @@ -2198,8 +2293,20 @@ PreservedAnalyses ReassociatePass::run(Function &F, FunctionAnalysisManager &) { // Calculate the rank map for F. BuildRankMap(F, RPOT); + // Build the pair map before running reassociate. + // Technically this would be more accurate if we did it after one round + // of reassociation, but in practice it doesn't seem to help much on + // real-world code, so don't waste the compile time running reassociate + // twice. + // If a user wants, they could expicitly run reassociate twice in their + // pass pipeline for further potential gains. + // It might also be possible to update the pair map during runtime, but the + // overhead of that may be large if there's many reassociable chains. + BuildPairMap(RPOT); + MadeChange = false; - // Traverse the same blocks that was analysed by BuildRankMap. + + // Traverse the same blocks that were analysed by BuildRankMap. for (BasicBlock *BI : RPOT) { assert(RankMap.count(&*BI) && "BB should be ranked."); // Optimize every instruction in the basic block. @@ -2238,9 +2345,11 @@ PreservedAnalyses ReassociatePass::run(Function &F, FunctionAnalysisManager &) { } } - // We are done with the rank map. + // We are done with the rank map and pair map. RankMap.clear(); ValueRankMap.clear(); + for (auto &Entry : PairMap) + Entry.clear(); if (MadeChange) { PreservedAnalyses PA; @@ -2253,10 +2362,13 @@ PreservedAnalyses ReassociatePass::run(Function &F, FunctionAnalysisManager &) { } namespace { + class ReassociateLegacyPass : public FunctionPass { ReassociatePass Impl; + public: static char ID; // Pass identification, replacement for typeid + ReassociateLegacyPass() : FunctionPass(ID) { initializeReassociateLegacyPassPass(*PassRegistry::getPassRegistry()); } @@ -2275,9 +2387,11 @@ namespace { AU.addPreserved<GlobalsAAWrapperPass>(); } }; -} + +} // end anonymous namespace char ReassociateLegacyPass::ID = 0; + INITIALIZE_PASS(ReassociateLegacyPass, "reassociate", "Reassociate expressions", false, false) |