diff options
Diffstat (limited to 'include/llvm/Transforms/Scalar/GVNExpression.h')
-rw-r--r-- | include/llvm/Transforms/Scalar/GVNExpression.h | 313 |
1 files changed, 171 insertions, 142 deletions
diff --git a/include/llvm/Transforms/Scalar/GVNExpression.h b/include/llvm/Transforms/Scalar/GVNExpression.h index 3458696e0687a..2670a0c1a5339 100644 --- a/include/llvm/Transforms/Scalar/GVNExpression.h +++ b/include/llvm/Transforms/Scalar/GVNExpression.h @@ -1,4 +1,4 @@ -//======- GVNExpression.h - GVN Expression classes -------*- C++ -*-==-------=// +//======- GVNExpression.h - GVN Expression classes --------------*- C++ -*-===// // // The LLVM Compiler Infrastructure // @@ -17,18 +17,22 @@ #define LLVM_TRANSFORMS_SCALAR_GVNEXPRESSION_H #include "llvm/ADT/Hashing.h" +#include "llvm/ADT/iterator_range.h" +#include "llvm/Analysis/MemorySSA.h" #include "llvm/IR/Constant.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/Value.h" #include "llvm/Support/Allocator.h" #include "llvm/Support/ArrayRecycler.h" +#include "llvm/Support/Casting.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" -#include "llvm/Transforms/Utils/MemorySSA.h" #include <algorithm> +#include <cassert> +#include <iterator> +#include <utility> namespace llvm { -class MemoryAccess; namespace GVNExpression { @@ -39,11 +43,13 @@ enum ExpressionType { ET_Unknown, ET_BasicStart, ET_Basic, - ET_Call, ET_AggregateValue, ET_Phi, + ET_MemoryStart, + ET_Call, ET_Load, ET_Store, + ET_MemoryEnd, ET_BasicEnd }; @@ -53,23 +59,22 @@ private: unsigned Opcode; public: - Expression(const Expression &) = delete; Expression(ExpressionType ET = ET_Base, unsigned O = ~2U) : EType(ET), Opcode(O) {} - void operator=(const Expression &) = delete; + Expression(const Expression &) = delete; + Expression &operator=(const Expression &) = delete; virtual ~Expression(); static unsigned getEmptyKey() { return ~0U; } static unsigned getTombstoneKey() { return ~1U; } - + bool operator!=(const Expression &Other) const { return !(*this == Other); } bool operator==(const Expression &Other) const { if (getOpcode() != Other.getOpcode()) return false; if (getOpcode() == getEmptyKey() || getOpcode() == getTombstoneKey()) return true; // Compare the expression type for anything but load and store. - // For load and store we set the opcode to zero. - // This is needed for load coercion. + // For load and store we set the opcode to zero to make them equal. if (getExpressionType() != ET_Load && getExpressionType() != ET_Store && getExpressionType() != Other.getExpressionType()) return false; @@ -83,9 +88,8 @@ public: void setOpcode(unsigned opcode) { Opcode = opcode; } ExpressionType getExpressionType() const { return EType; } - virtual hash_code getHashValue() const { - return hash_combine(getExpressionType(), getOpcode()); - } + // We deliberately leave the expression type out of the hash value. + virtual hash_code getHashValue() const { return getOpcode(); } // // Debugging support @@ -101,7 +105,11 @@ public: printInternal(OS, true); OS << "}"; } - void dump() const { print(dbgs()); } + + LLVM_DUMP_METHOD void dump() const { + print(dbgs()); + dbgs() << "\n"; + } }; inline raw_ostream &operator<<(raw_ostream &OS, const Expression &E) { @@ -119,20 +127,20 @@ private: Type *ValueType; public: - static bool classof(const Expression *EB) { - ExpressionType ET = EB->getExpressionType(); - return ET > ET_BasicStart && ET < ET_BasicEnd; - } - BasicExpression(unsigned NumOperands) : BasicExpression(NumOperands, ET_Basic) {} BasicExpression(unsigned NumOperands, ExpressionType ET) : Expression(ET), Operands(nullptr), MaxOperands(NumOperands), NumOperands(0), ValueType(nullptr) {} - virtual ~BasicExpression() override; - void operator=(const BasicExpression &) = delete; - BasicExpression(const BasicExpression &) = delete; BasicExpression() = delete; + BasicExpression(const BasicExpression &) = delete; + BasicExpression &operator=(const BasicExpression &) = delete; + ~BasicExpression() override; + + static bool classof(const Expression *EB) { + ExpressionType ET = EB->getExpressionType(); + return ET > ET_BasicStart && ET < ET_BasicEnd; + } /// \brief Swap two operands. Used during GVN to put commutative operands in /// order. @@ -185,7 +193,7 @@ public: void setType(Type *T) { ValueType = T; } Type *getType() const { return ValueType; } - virtual bool equals(const Expression &Other) const override { + bool equals(const Expression &Other) const override { if (getOpcode() != Other.getOpcode()) return false; @@ -194,15 +202,15 @@ public: std::equal(op_begin(), op_end(), OE.op_begin()); } - virtual hash_code getHashValue() const override { - return hash_combine(getExpressionType(), getOpcode(), ValueType, + hash_code getHashValue() const override { + return hash_combine(this->Expression::getHashValue(), ValueType, hash_combine_range(op_begin(), op_end())); } // // Debugging support // - virtual void printInternal(raw_ostream &OS, bool PrintEType) const override { + void printInternal(raw_ostream &OS, bool PrintEType) const override { if (PrintEType) OS << "ExpressionTypeBasic, "; @@ -216,6 +224,7 @@ public: OS << "} "; } }; + class op_inserter : public std::iterator<std::output_iterator_tag, void, void, void, void> { private: @@ -235,131 +244,143 @@ public: op_inserter &operator++(int) { return *this; } }; -class CallExpression final : public BasicExpression { +class MemoryExpression : public BasicExpression { private: - CallInst *Call; - MemoryAccess *DefiningAccess; + const MemoryAccess *MemoryLeader; public: + MemoryExpression(unsigned NumOperands, enum ExpressionType EType, + const MemoryAccess *MemoryLeader) + : BasicExpression(NumOperands, EType), MemoryLeader(MemoryLeader){}; + + MemoryExpression() = delete; + MemoryExpression(const MemoryExpression &) = delete; + MemoryExpression &operator=(const MemoryExpression &) = delete; static bool classof(const Expression *EB) { - return EB->getExpressionType() == ET_Call; + return EB->getExpressionType() > ET_MemoryStart && + EB->getExpressionType() < ET_MemoryEnd; + } + hash_code getHashValue() const override { + return hash_combine(this->BasicExpression::getHashValue(), MemoryLeader); } - CallExpression(unsigned NumOperands, CallInst *C, MemoryAccess *DA) - : BasicExpression(NumOperands, ET_Call), Call(C), DefiningAccess(DA) {} - void operator=(const CallExpression &) = delete; - CallExpression(const CallExpression &) = delete; - CallExpression() = delete; - virtual ~CallExpression() override; - - virtual bool equals(const Expression &Other) const override { + bool equals(const Expression &Other) const override { if (!this->BasicExpression::equals(Other)) return false; - const auto &OE = cast<CallExpression>(Other); - return DefiningAccess == OE.DefiningAccess; + const MemoryExpression &OtherMCE = cast<MemoryExpression>(Other); + + return MemoryLeader == OtherMCE.MemoryLeader; } - virtual hash_code getHashValue() const override { - return hash_combine(this->BasicExpression::getHashValue(), DefiningAccess); + const MemoryAccess *getMemoryLeader() const { return MemoryLeader; } + void setMemoryLeader(const MemoryAccess *ML) { MemoryLeader = ML; } +}; + +class CallExpression final : public MemoryExpression { +private: + CallInst *Call; + +public: + CallExpression(unsigned NumOperands, CallInst *C, + const MemoryAccess *MemoryLeader) + : MemoryExpression(NumOperands, ET_Call, MemoryLeader), Call(C) {} + CallExpression() = delete; + CallExpression(const CallExpression &) = delete; + CallExpression &operator=(const CallExpression &) = delete; + ~CallExpression() override; + + static bool classof(const Expression *EB) { + return EB->getExpressionType() == ET_Call; } // // Debugging support // - virtual void printInternal(raw_ostream &OS, bool PrintEType) const override { + void printInternal(raw_ostream &OS, bool PrintEType) const override { if (PrintEType) OS << "ExpressionTypeCall, "; this->BasicExpression::printInternal(OS, false); - OS << " represents call at " << Call; + OS << " represents call at "; + Call->printAsOperand(OS); } }; -class LoadExpression final : public BasicExpression { +class LoadExpression final : public MemoryExpression { private: LoadInst *Load; - MemoryAccess *DefiningAccess; unsigned Alignment; public: - static bool classof(const Expression *EB) { - return EB->getExpressionType() == ET_Load; - } - - LoadExpression(unsigned NumOperands, LoadInst *L, MemoryAccess *DA) - : LoadExpression(ET_Load, NumOperands, L, DA) {} + LoadExpression(unsigned NumOperands, LoadInst *L, + const MemoryAccess *MemoryLeader) + : LoadExpression(ET_Load, NumOperands, L, MemoryLeader) {} LoadExpression(enum ExpressionType EType, unsigned NumOperands, LoadInst *L, - MemoryAccess *DA) - : BasicExpression(NumOperands, EType), Load(L), DefiningAccess(DA) { + const MemoryAccess *MemoryLeader) + : MemoryExpression(NumOperands, EType, MemoryLeader), Load(L) { Alignment = L ? L->getAlignment() : 0; } - void operator=(const LoadExpression &) = delete; - LoadExpression(const LoadExpression &) = delete; LoadExpression() = delete; - virtual ~LoadExpression() override; + LoadExpression(const LoadExpression &) = delete; + LoadExpression &operator=(const LoadExpression &) = delete; + ~LoadExpression() override; + + static bool classof(const Expression *EB) { + return EB->getExpressionType() == ET_Load; + } LoadInst *getLoadInst() const { return Load; } void setLoadInst(LoadInst *L) { Load = L; } - MemoryAccess *getDefiningAccess() const { return DefiningAccess; } - void setDefiningAccess(MemoryAccess *MA) { DefiningAccess = MA; } unsigned getAlignment() const { return Alignment; } void setAlignment(unsigned Align) { Alignment = Align; } - virtual bool equals(const Expression &Other) const override; - - virtual hash_code getHashValue() const override { - return hash_combine(getOpcode(), getType(), DefiningAccess, - hash_combine_range(op_begin(), op_end())); - } + bool equals(const Expression &Other) const override; // // Debugging support // - virtual void printInternal(raw_ostream &OS, bool PrintEType) const override { + void printInternal(raw_ostream &OS, bool PrintEType) const override { if (PrintEType) OS << "ExpressionTypeLoad, "; this->BasicExpression::printInternal(OS, false); - OS << " represents Load at " << Load; - OS << " with DefiningAccess " << *DefiningAccess; + OS << " represents Load at "; + Load->printAsOperand(OS); + OS << " with MemoryLeader " << *getMemoryLeader(); } }; -class StoreExpression final : public BasicExpression { +class StoreExpression final : public MemoryExpression { private: StoreInst *Store; - MemoryAccess *DefiningAccess; + Value *StoredValue; public: + StoreExpression(unsigned NumOperands, StoreInst *S, Value *StoredValue, + const MemoryAccess *MemoryLeader) + : MemoryExpression(NumOperands, ET_Store, MemoryLeader), Store(S), + StoredValue(StoredValue) {} + StoreExpression() = delete; + StoreExpression(const StoreExpression &) = delete; + StoreExpression &operator=(const StoreExpression &) = delete; + ~StoreExpression() override; + static bool classof(const Expression *EB) { return EB->getExpressionType() == ET_Store; } - StoreExpression(unsigned NumOperands, StoreInst *S, MemoryAccess *DA) - : BasicExpression(NumOperands, ET_Store), Store(S), DefiningAccess(DA) {} - void operator=(const StoreExpression &) = delete; - StoreExpression(const StoreExpression &) = delete; - StoreExpression() = delete; - virtual ~StoreExpression() override; - StoreInst *getStoreInst() const { return Store; } - MemoryAccess *getDefiningAccess() const { return DefiningAccess; } + Value *getStoredValue() const { return StoredValue; } - virtual bool equals(const Expression &Other) const override; + bool equals(const Expression &Other) const override; - virtual hash_code getHashValue() const override { - return hash_combine(getOpcode(), getType(), DefiningAccess, - hash_combine_range(op_begin(), op_end())); - } - - // // Debugging support // - virtual void printInternal(raw_ostream &OS, bool PrintEType) const override { + void printInternal(raw_ostream &OS, bool PrintEType) const override { if (PrintEType) OS << "ExpressionTypeStore, "; this->BasicExpression::printInternal(OS, false); - OS << " represents Store at " << Store; - OS << " with DefiningAccess " << *DefiningAccess; + OS << " represents Store " << *Store; + OS << " with MemoryLeader " << *getMemoryLeader(); } }; @@ -370,19 +391,19 @@ private: unsigned *IntOperands; public: - static bool classof(const Expression *EB) { - return EB->getExpressionType() == ET_AggregateValue; - } - AggregateValueExpression(unsigned NumOperands, unsigned NumIntOperands) : BasicExpression(NumOperands, ET_AggregateValue), MaxIntOperands(NumIntOperands), NumIntOperands(0), IntOperands(nullptr) {} - - void operator=(const AggregateValueExpression &) = delete; - AggregateValueExpression(const AggregateValueExpression &) = delete; AggregateValueExpression() = delete; - virtual ~AggregateValueExpression() override; + AggregateValueExpression(const AggregateValueExpression &) = delete; + AggregateValueExpression & + operator=(const AggregateValueExpression &) = delete; + ~AggregateValueExpression() override; + + static bool classof(const Expression *EB) { + return EB->getExpressionType() == ET_AggregateValue; + } typedef unsigned *int_arg_iterator; typedef const unsigned *const_int_arg_iterator; @@ -407,7 +428,7 @@ public: IntOperands = Allocator.Allocate<unsigned>(MaxIntOperands); } - virtual bool equals(const Expression &Other) const override { + bool equals(const Expression &Other) const override { if (!this->BasicExpression::equals(Other)) return false; const AggregateValueExpression &OE = cast<AggregateValueExpression>(Other); @@ -415,7 +436,7 @@ public: std::equal(int_op_begin(), int_op_end(), OE.int_op_begin()); } - virtual hash_code getHashValue() const override { + hash_code getHashValue() const override { return hash_combine(this->BasicExpression::getHashValue(), hash_combine_range(int_op_begin(), int_op_end())); } @@ -423,7 +444,7 @@ public: // // Debugging support // - virtual void printInternal(raw_ostream &OS, bool PrintEType) const override { + void printInternal(raw_ostream &OS, bool PrintEType) const override { if (PrintEType) OS << "ExpressionTypeAggregateValue, "; this->BasicExpression::printInternal(OS, false); @@ -434,6 +455,7 @@ public: OS << "}"; } }; + class int_op_inserter : public std::iterator<std::output_iterator_tag, void, void, void, void> { private: @@ -443,6 +465,7 @@ private: public: explicit int_op_inserter(AggregateValueExpression &E) : AVE(&E) {} explicit int_op_inserter(AggregateValueExpression *E) : AVE(E) {} + int_op_inserter &operator=(unsigned int val) { AVE->int_op_push_back(val); return *this; @@ -457,32 +480,32 @@ private: BasicBlock *BB; public: - static bool classof(const Expression *EB) { - return EB->getExpressionType() == ET_Phi; - } - PHIExpression(unsigned NumOperands, BasicBlock *B) : BasicExpression(NumOperands, ET_Phi), BB(B) {} - void operator=(const PHIExpression &) = delete; - PHIExpression(const PHIExpression &) = delete; PHIExpression() = delete; - virtual ~PHIExpression() override; + PHIExpression(const PHIExpression &) = delete; + PHIExpression &operator=(const PHIExpression &) = delete; + ~PHIExpression() override; - virtual bool equals(const Expression &Other) const override { + static bool classof(const Expression *EB) { + return EB->getExpressionType() == ET_Phi; + } + + bool equals(const Expression &Other) const override { if (!this->BasicExpression::equals(Other)) return false; const PHIExpression &OE = cast<PHIExpression>(Other); return BB == OE.BB; } - virtual hash_code getHashValue() const override { + hash_code getHashValue() const override { return hash_combine(this->BasicExpression::getHashValue(), BB); } // // Debugging support // - virtual void printInternal(raw_ostream &OS, bool PrintEType) const override { + void printInternal(raw_ostream &OS, bool PrintEType) const override { if (PrintEType) OS << "ExpressionTypePhi, "; this->BasicExpression::printInternal(OS, false); @@ -495,31 +518,32 @@ private: Value *VariableValue; public: + VariableExpression(Value *V) : Expression(ET_Variable), VariableValue(V) {} + VariableExpression() = delete; + VariableExpression(const VariableExpression &) = delete; + VariableExpression &operator=(const VariableExpression &) = delete; + static bool classof(const Expression *EB) { return EB->getExpressionType() == ET_Variable; } - VariableExpression(Value *V) : Expression(ET_Variable), VariableValue(V) {} - void operator=(const VariableExpression &) = delete; - VariableExpression(const VariableExpression &) = delete; - VariableExpression() = delete; - Value *getVariableValue() const { return VariableValue; } void setVariableValue(Value *V) { VariableValue = V; } - virtual bool equals(const Expression &Other) const override { + + bool equals(const Expression &Other) const override { const VariableExpression &OC = cast<VariableExpression>(Other); return VariableValue == OC.VariableValue; } - virtual hash_code getHashValue() const override { - return hash_combine(getExpressionType(), VariableValue->getType(), - VariableValue); + hash_code getHashValue() const override { + return hash_combine(this->Expression::getHashValue(), + VariableValue->getType(), VariableValue); } // // Debugging support // - virtual void printInternal(raw_ostream &OS, bool PrintEType) const override { + void printInternal(raw_ostream &OS, bool PrintEType) const override { if (PrintEType) OS << "ExpressionTypeVariable, "; this->Expression::printInternal(OS, false); @@ -529,36 +553,36 @@ public: class ConstantExpression final : public Expression { private: - Constant *ConstantValue; + Constant *ConstantValue = nullptr; public: - static bool classof(const Expression *EB) { - return EB->getExpressionType() == ET_Constant; - } - - ConstantExpression() : Expression(ET_Constant), ConstantValue(NULL) {} + ConstantExpression() : Expression(ET_Constant) {} ConstantExpression(Constant *constantValue) : Expression(ET_Constant), ConstantValue(constantValue) {} - void operator=(const ConstantExpression &) = delete; ConstantExpression(const ConstantExpression &) = delete; + ConstantExpression &operator=(const ConstantExpression &) = delete; + + static bool classof(const Expression *EB) { + return EB->getExpressionType() == ET_Constant; + } Constant *getConstantValue() const { return ConstantValue; } void setConstantValue(Constant *V) { ConstantValue = V; } - virtual bool equals(const Expression &Other) const override { + bool equals(const Expression &Other) const override { const ConstantExpression &OC = cast<ConstantExpression>(Other); return ConstantValue == OC.ConstantValue; } - virtual hash_code getHashValue() const override { - return hash_combine(getExpressionType(), ConstantValue->getType(), - ConstantValue); + hash_code getHashValue() const override { + return hash_combine(this->Expression::getHashValue(), + ConstantValue->getType(), ConstantValue); } // // Debugging support // - virtual void printInternal(raw_ostream &OS, bool PrintEType) const override { + void printInternal(raw_ostream &OS, bool PrintEType) const override { if (PrintEType) OS << "ExpressionTypeConstant, "; this->Expression::printInternal(OS, false); @@ -571,35 +595,40 @@ private: Instruction *Inst; public: + UnknownExpression(Instruction *I) : Expression(ET_Unknown), Inst(I) {} + UnknownExpression() = delete; + UnknownExpression(const UnknownExpression &) = delete; + UnknownExpression &operator=(const UnknownExpression &) = delete; + static bool classof(const Expression *EB) { return EB->getExpressionType() == ET_Unknown; } - UnknownExpression(Instruction *I) : Expression(ET_Unknown), Inst(I) {} - void operator=(const UnknownExpression &) = delete; - UnknownExpression(const UnknownExpression &) = delete; - UnknownExpression() = delete; - Instruction *getInstruction() const { return Inst; } void setInstruction(Instruction *I) { Inst = I; } - virtual bool equals(const Expression &Other) const override { + + bool equals(const Expression &Other) const override { const auto &OU = cast<UnknownExpression>(Other); return Inst == OU.Inst; } - virtual hash_code getHashValue() const override { - return hash_combine(getExpressionType(), Inst); + + hash_code getHashValue() const override { + return hash_combine(this->Expression::getHashValue(), Inst); } + // // Debugging support // - virtual void printInternal(raw_ostream &OS, bool PrintEType) const override { + void printInternal(raw_ostream &OS, bool PrintEType) const override { if (PrintEType) OS << "ExpressionTypeUnknown, "; this->Expression::printInternal(OS, false); OS << " inst = " << *Inst; } }; -} -} -#endif +} // end namespace GVNExpression + +} // end namespace llvm + +#endif // LLVM_TRANSFORMS_SCALAR_GVNEXPRESSION_H |