diff options
Diffstat (limited to 'lib/Transforms/Scalar/Reassociate.cpp')
-rw-r--r-- | lib/Transforms/Scalar/Reassociate.cpp | 103 |
1 files changed, 72 insertions, 31 deletions
diff --git a/lib/Transforms/Scalar/Reassociate.cpp b/lib/Transforms/Scalar/Reassociate.cpp index cb893eab1654..fa8c9e2a5fe4 100644 --- a/lib/Transforms/Scalar/Reassociate.cpp +++ b/lib/Transforms/Scalar/Reassociate.cpp @@ -1,9 +1,8 @@ //===- Reassociate.cpp - Reassociate binary expressions -------------------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // @@ -267,12 +266,16 @@ static BinaryOperator *CreateNeg(Value *S1, const Twine &Name, /// Replace 0-X with X*-1. static BinaryOperator *LowerNegateToMultiply(Instruction *Neg) { + assert((isa<UnaryOperator>(Neg) || isa<BinaryOperator>(Neg)) && + "Expected a Negate!"); + // FIXME: It's not safe to lower a unary FNeg into a FMul by -1.0. + unsigned OpNo = isa<BinaryOperator>(Neg) ? 1 : 0; Type *Ty = Neg->getType(); Constant *NegOne = Ty->isIntOrIntVectorTy() ? ConstantInt::getAllOnesValue(Ty) : ConstantFP::get(Ty, -1.0); - BinaryOperator *Res = CreateMul(Neg->getOperand(1), NegOne, "", Neg, Neg); - Neg->setOperand(1, Constant::getNullValue(Ty)); // Drop use of op. + BinaryOperator *Res = CreateMul(Neg->getOperand(OpNo), NegOne, "", Neg, Neg); + Neg->setOperand(OpNo, Constant::getNullValue(Ty)); // Drop use of op. Res->takeName(Neg); Neg->replaceAllUsesWith(Res); Res->setDebugLoc(Neg->getDebugLoc()); @@ -445,8 +448,10 @@ using RepeatedValue = std::pair<Value*, APInt>; /// that have all uses inside the expression (i.e. only used by non-leaf nodes /// of the expression) if it can turn them into binary operators of the right /// type and thus make the expression bigger. -static bool LinearizeExprTree(BinaryOperator *I, +static bool LinearizeExprTree(Instruction *I, SmallVectorImpl<RepeatedValue> &Ops) { + assert((isa<UnaryOperator>(I) || isa<BinaryOperator>(I)) && + "Expected a UnaryOperator or BinaryOperator!"); LLVM_DEBUG(dbgs() << "LINEARIZE: " << *I << '\n'); unsigned Bitwidth = I->getType()->getScalarType()->getPrimitiveSizeInBits(); unsigned Opcode = I->getOpcode(); @@ -463,7 +468,7 @@ static bool LinearizeExprTree(BinaryOperator *I, // with their weights, representing a certain number of paths to the operator. // If an operator occurs in the worklist multiple times then we found multiple // ways to get to it. - SmallVector<std::pair<BinaryOperator*, APInt>, 8> Worklist; // (Op, Weight) + SmallVector<std::pair<Instruction*, APInt>, 8> Worklist; // (Op, Weight) Worklist.push_back(std::make_pair(I, APInt(Bitwidth, 1))); bool Changed = false; @@ -490,10 +495,10 @@ static bool LinearizeExprTree(BinaryOperator *I, SmallPtrSet<Value *, 8> Visited; // For sanity checking the iteration scheme. #endif while (!Worklist.empty()) { - std::pair<BinaryOperator*, APInt> P = Worklist.pop_back_val(); + std::pair<Instruction*, APInt> P = Worklist.pop_back_val(); I = P.first; // We examine the operands of this binary operator. - for (unsigned OpIdx = 0; OpIdx < 2; ++OpIdx) { // Visit operands. + for (unsigned OpIdx = 0; OpIdx < I->getNumOperands(); ++OpIdx) { // Visit operands. Value *Op = I->getOperand(OpIdx); APInt Weight = P.second; // Number of paths to this operand. LLVM_DEBUG(dbgs() << "OPERAND: " << *Op << " (" << Weight << ")\n"); @@ -573,14 +578,14 @@ static bool LinearizeExprTree(BinaryOperator *I, // If this is a multiply expression, turn any internal negations into // multiplies by -1 so they can be reassociated. - if (BinaryOperator *BO = dyn_cast<BinaryOperator>(Op)) - if ((Opcode == Instruction::Mul && match(BO, m_Neg(m_Value()))) || - (Opcode == Instruction::FMul && match(BO, m_FNeg(m_Value())))) { + if (Instruction *Tmp = dyn_cast<Instruction>(Op)) + if ((Opcode == Instruction::Mul && match(Tmp, m_Neg(m_Value()))) || + (Opcode == Instruction::FMul && match(Tmp, m_FNeg(m_Value())))) { LLVM_DEBUG(dbgs() << "MORPH LEAF: " << *Op << " (" << Weight << ") TO "); - BO = LowerNegateToMultiply(BO); - LLVM_DEBUG(dbgs() << *BO << '\n'); - Worklist.push_back(std::make_pair(BO, Weight)); + Tmp = LowerNegateToMultiply(Tmp); + LLVM_DEBUG(dbgs() << *Tmp << '\n'); + Worklist.push_back(std::make_pair(Tmp, Weight)); Changed = true; continue; } @@ -862,6 +867,8 @@ static Value *NegateValue(Value *V, Instruction *BI, if (TheNeg->getParent()->getParent() != BI->getParent()->getParent()) continue; + bool FoundCatchSwitch = false; + BasicBlock::iterator InsertPt; if (Instruction *InstInput = dyn_cast<Instruction>(V)) { if (InvokeInst *II = dyn_cast<InvokeInst>(InstInput)) { @@ -869,10 +876,30 @@ static Value *NegateValue(Value *V, Instruction *BI, } else { InsertPt = ++InstInput->getIterator(); } - while (isa<PHINode>(InsertPt)) ++InsertPt; + + const BasicBlock *BB = InsertPt->getParent(); + + // Make sure we don't move anything before PHIs or exception + // handling pads. + while (InsertPt != BB->end() && (isa<PHINode>(InsertPt) || + InsertPt->isEHPad())) { + if (isa<CatchSwitchInst>(InsertPt)) + // A catchswitch cannot have anything in the block except + // itself and PHIs. We'll bail out below. + FoundCatchSwitch = true; + ++InsertPt; + } } else { InsertPt = TheNeg->getParent()->getParent()->getEntryBlock().begin(); } + + // We found a catchswitch in the block where we want to move the + // neg. We cannot move anything into that block. Bail and just + // create the neg before BI, as if we hadn't found an existing + // neg. + if (FoundCatchSwitch) + break; + TheNeg->moveBefore(&*InsertPt); if (TheNeg->getOpcode() == Instruction::Sub) { TheNeg->setHasNoUnsignedWrap(false); @@ -1329,8 +1356,7 @@ Value *ReassociatePass::OptimizeXor(Instruction *I, // So, if Rank(X) < Rank(Y) < Rank(Z), it means X is defined earlier // than Y which is defined earlier than Z. Permute "x | 1", "Y & 2", // "z" in the order of X-Y-Z is better than any other orders. - std::stable_sort(OpndPtrs.begin(), OpndPtrs.end(), - [](XorOpnd *LHS, XorOpnd *RHS) { + llvm::stable_sort(OpndPtrs, [](XorOpnd *LHS, XorOpnd *RHS) { return LHS->getSymbolicRank() < RHS->getSymbolicRank(); }); @@ -1687,8 +1713,7 @@ static bool collectMultiplyFactors(SmallVectorImpl<ValueEntry> &Ops, // below our mininum of '4'. assert(FactorPowerSum >= 4); - std::stable_sort(Factors.begin(), Factors.end(), - [](const Factor &LHS, const Factor &RHS) { + llvm::stable_sort(Factors, [](const Factor &LHS, const Factor &RHS) { return LHS.Power > RHS.Power; }); return true; @@ -1801,7 +1826,7 @@ Value *ReassociatePass::OptimizeMul(BinaryOperator *I, return V; ValueEntry NewEntry = ValueEntry(getRank(V), V); - Ops.insert(std::lower_bound(Ops.begin(), Ops.end(), NewEntry), NewEntry); + Ops.insert(llvm::lower_bound(Ops, NewEntry), NewEntry); return nullptr; } @@ -2001,7 +2026,7 @@ Instruction *ReassociatePass::canonicalizeNegConstExpr(Instruction *I) { /// instructions is not allowed. void ReassociatePass::OptimizeInst(Instruction *I) { // Only consider operations that we understand. - if (!isa<BinaryOperator>(I)) + if (!isa<UnaryOperator>(I) && !isa<BinaryOperator>(I)) return; if (I->getOpcode() == Instruction::Shl && isa<ConstantInt>(I->getOperand(1))) @@ -2066,7 +2091,8 @@ void ReassociatePass::OptimizeInst(Instruction *I) { I = NI; } } - } else if (I->getOpcode() == Instruction::FSub) { + } else if (I->getOpcode() == Instruction::FNeg || + I->getOpcode() == Instruction::FSub) { if (ShouldBreakUpSubtract(I)) { Instruction *NI = BreakUpSubtract(I, RedoInsts); RedoInsts.insert(I); @@ -2075,7 +2101,9 @@ void ReassociatePass::OptimizeInst(Instruction *I) { } else if (match(I, m_FNeg(m_Value()))) { // Otherwise, this is a negation. See if the operand is a multiply tree // and if this is not an inner node of a multiply tree. - if (isReassociableOp(I->getOperand(1), Instruction::FMul) && + Value *Op = isa<BinaryOperator>(I) ? I->getOperand(1) : + I->getOperand(0); + if (isReassociableOp(Op, Instruction::FMul) && (!I->hasOneUse() || !isReassociableOp(I->user_back(), Instruction::FMul))) { // If the negate was simplified, revisit the users to see if we can @@ -2142,7 +2170,7 @@ void ReassociatePass::ReassociateExpression(BinaryOperator *I) { // positions maintained (and so the compiler is deterministic). Note that // this sorts so that the highest ranking values end up at the beginning of // the vector. - std::stable_sort(Ops.begin(), Ops.end()); + llvm::stable_sort(Ops); // Now that we have the expression tree in a convenient // sorted form, optimize it globally if possible. @@ -2218,8 +2246,15 @@ void ReassociatePass::ReassociateExpression(BinaryOperator *I) { if (std::less<Value *>()(Op1, Op0)) std::swap(Op0, Op1); auto it = PairMap[Idx].find({Op0, Op1}); - if (it != PairMap[Idx].end()) - Score += it->second; + if (it != PairMap[Idx].end()) { + // Functions like BreakUpSubtract() can erase the Values we're using + // as keys and create new Values after we built the PairMap. There's a + // small chance that the new nodes can have the same address as + // something already in the table. We shouldn't accumulate the stored + // score in that case as it refers to the wrong Value. + if (it->second.isValid()) + Score += it->second.Score; + } unsigned MaxRank = std::max(Ops[i].Rank, Ops[j].Rank); if (Score > Max || (Score == Max && MaxRank < BestRank)) { @@ -2288,9 +2323,15 @@ ReassociatePass::BuildPairMap(ReversePostOrderTraversal<Function *> &RPOT) { std::swap(Op0, Op1); if (!Visited.insert({Op0, Op1}).second) continue; - auto res = PairMap[BinaryIdx].insert({{Op0, Op1}, 1}); - if (!res.second) - ++res.first->second; + auto res = PairMap[BinaryIdx].insert({{Op0, Op1}, {Op0, Op1, 1}}); + if (!res.second) { + // If either key value has been erased then we've got the same + // address by coincidence. That can't happen here because nothing is + // erasing values but it can happen by the time we're querying the + // map. + assert(res.first->second.isValid() && "WeakVH invalidated"); + ++res.first->second.Score; + } } } } |