Diffstat (limited to 'llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp')
 llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp | 1531 +++++++++++++++++----
 1 file changed, 1282 insertions(+), 249 deletions(-)
diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
index 0ff6ee8bcfcc2..90314b17b5e25 100644
--- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
+++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
@@ -9,8 +9,11 @@
// Lower matrix intrinsics to vector operations.
//
// TODO:
-// * Implement multiply & add fusion
-// * Add remark, summarizing the available matrix optimization opportunities.
+// * Improve fusion:
+// * Support more cases, e.g. multiply-add, multiply-sub, operands/results
+// transposed.
+// * Improve cost-modeling, e.g. choose different number of rows/columns
+//    for tiles, consider cost of copies on alias.
//
//===----------------------------------------------------------------------===//
@@ -18,10 +21,15 @@
#include "llvm/ADT/GraphTraits.h"
#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/ADT/SmallVector.h"
+#include "llvm/Analysis/AliasAnalysis.h"
+#include "llvm/Analysis/DomTreeUpdater.h"
+#include "llvm/Analysis/OptimizationRemarkEmitter.h"
#include "llvm/Analysis/TargetTransformInfo.h"
+#include "llvm/Analysis/ValueTracking.h"
#include "llvm/Analysis/VectorUtils.h"
#include "llvm/IR/CFG.h"
#include "llvm/IR/DataLayout.h"
+#include "llvm/IR/DebugInfoMetadata.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
@@ -29,30 +37,69 @@
#include "llvm/IR/PatternMatch.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
+#include "llvm/Support/Alignment.h"
+#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Transforms/Scalar.h"
+#include "llvm/Transforms/Utils/BasicBlockUtils.h"
using namespace llvm;
using namespace PatternMatch;
#define DEBUG_TYPE "lower-matrix-intrinsics"
-static cl::opt<bool> EnableShapePropagation("matrix-propagate-shape",
- cl::init(true));
-
+static cl::opt<bool> EnableShapePropagation(
+ "matrix-propagate-shape", cl::init(true), cl::Hidden,
+ cl::desc("Enable/disable shape propagation from matrix intrinsics to other "
+ "instructions."));
+
+static cl::opt<bool>
+ FuseMatrix("fuse-matrix", cl::init(true), cl::Hidden,
+ cl::desc("Enable/disable fusing matrix instructions."));
+// TODO: Allow and use non-square tiles.
+static cl::opt<unsigned> TileSize(
+ "fuse-matrix-tile-size", cl::init(4), cl::Hidden,
+ cl::desc(
+ "Tile size for matrix instruction fusion using square-shaped tiles."));
+static cl::opt<bool> ForceFusion(
+ "force-fuse-matrix", cl::init(false), cl::Hidden,
+ cl::desc("Force matrix instruction fusion even if not profitable."));
static cl::opt<bool> AllowContractEnabled(
"matrix-allow-contract", cl::init(false), cl::Hidden,
cl::desc("Allow the use of FMAs if available and profitable. This may "
"result in different results, due to less rounding error."));
+enum class MatrixLayoutTy { ColumnMajor, RowMajor };
+
+static cl::opt<MatrixLayoutTy> MatrixLayout(
+ "matrix-default-layout", cl::init(MatrixLayoutTy::ColumnMajor),
+ cl::desc("Sets the default matrix layout"),
+ cl::values(clEnumValN(MatrixLayoutTy::ColumnMajor, "column-major",
+ "Use column-major layout"),
+ clEnumValN(MatrixLayoutTy::RowMajor, "row-major",
+ "Use row-major layout")));
+
+/// Helper function to either return \p Scope, if it is a subprogram, or the
+/// attached subprogram for a local scope.
+static DISubprogram *getSubprogram(DIScope *Scope) {
+ if (auto *Subprogram = dyn_cast<DISubprogram>(Scope))
+ return Subprogram;
+ return cast<DILocalScope>(Scope)->getSubprogram();
+}
+
namespace {
-// Given an element poitner \p BasePtr to the start of a (sub) matrix, compute
-// the start address of column \p Col with type (\p EltType x \p NumRows)
-// assuming \p Stride elements between start two consecutive columns.
-// \p Stride must be >= \p NumRows.
+// Given an element pointer \p BasePtr to the start of a (sub) matrix, compute
+// the start address of vector \p VecIdx with type (\p EltType x \p NumElements)
+// assuming \p Stride elements between the starts of two consecutive vectors.
+// \p Stride must be >= \p NumElements.
+// For column-major matrixes, the function computes the address of a column
+// vector and \p NumElements must be set to the number of elements in a column
+// (= number of rows of the matrix). For row-major matrixes, the function
+// computes the address of a row vector and \p NumElements must be set to the
+// number of elements in a row (= number of columns of the matrix).
//
-// Consider a 4x4 matrix like below
+// Consider a 4x4 matrix in column-major layout like below
//
// 0 1 2 3
// 0 v_0_0 v_0_1 v_0_2 v_0_3
@@ -62,14 +109,14 @@ namespace {
// To compute the column addresses for a 2x3 sub-matrix at row 1 and column 1,
// we need a pointer to the first element of the submatrix as base pointer.
-// Then we can use computeColumnAddr to compute the addresses for the columns
+// Then we can use computeVectorAddr to compute the addresses for the columns
// of the sub-matrix.
//
-// Column 0: computeColumnAddr(Base, 0 (column), 4 (stride), 2 (num rows), ..)
+// Column 0: computeVectorAddr(Base, 0 (column), 4 (stride), 2 (num rows), ..)
// -> just returns Base
-// Column 1: computeColumnAddr(Base, 1 (column), 4 (stride), 2 (num rows), ..)
+// Column 1: computeVectorAddr(Base, 1 (column), 4 (stride), 2 (num rows), ..)
// -> returns Base + (1 * 4)
-// Column 2: computeColumnAddr(Base, 2 (column), 4 (stride), 2 (num rows), ..)
+// Column 2: computeVectorAddr(Base, 2 (column), 4 (stride), 2 (num rows), ..)
// -> returns Base + (2 * 4)
//
// The graphic below illustrates the number of elements in a column (marked
@@ -82,30 +129,30 @@ namespace {
// v_2_0 |v_2_1 |v_2_2 |v_2_3
// v_3_0 {v_3_1 {v_3_2 v_3_3
//
-Value *computeColumnAddr(Value *BasePtr, Value *Col, Value *Stride,
- unsigned NumRows, Type *EltType,
+Value *computeVectorAddr(Value *BasePtr, Value *VecIdx, Value *Stride,
+ unsigned NumElements, Type *EltType,
IRBuilder<> &Builder) {
assert((!isa<ConstantInt>(Stride) ||
- cast<ConstantInt>(Stride)->getZExtValue() >= NumRows) &&
- "Stride must be >= the number of rows.");
+ cast<ConstantInt>(Stride)->getZExtValue() >= NumElements) &&
+ "Stride must be >= the number of elements in the result vector.");
unsigned AS = cast<PointerType>(BasePtr->getType())->getAddressSpace();
- // Compute the start of the column with index Col as Col * Stride.
- Value *ColumnStart = Builder.CreateMul(Col, Stride, "col.start");
+ // Compute the start of the vector with index VecIdx as VecIdx * Stride.
+ Value *VecStart = Builder.CreateMul(VecIdx, Stride, "vec.start");
- // Get pointer to the start of the selected column. Skip GEP creation,
- // if we select column 0.
- if (isa<ConstantInt>(ColumnStart) && cast<ConstantInt>(ColumnStart)->isZero())
- ColumnStart = BasePtr;
+ // Get pointer to the start of the selected vector. Skip GEP creation,
+ // if we select vector 0.
+ if (isa<ConstantInt>(VecStart) && cast<ConstantInt>(VecStart)->isZero())
+ VecStart = BasePtr;
else
- ColumnStart = Builder.CreateGEP(EltType, BasePtr, ColumnStart, "col.gep");
+ VecStart = Builder.CreateGEP(EltType, BasePtr, VecStart, "vec.gep");
- // Cast elementwise column start pointer to a pointer to a column
- // (EltType x NumRows)*.
- Type *ColumnType = VectorType::get(EltType, NumRows);
- Type *ColumnPtrType = PointerType::get(ColumnType, AS);
- return Builder.CreatePointerCast(ColumnStart, ColumnPtrType, "col.cast");
+ // Cast elementwise vector start pointer to a pointer to a vector
+ // (EltType x NumElements)*.
+ auto *VecType = FixedVectorType::get(EltType, NumElements);
+ Type *VecPtrType = PointerType::get(VecType, AS);
+ return Builder.CreatePointerCast(VecStart, VecPtrType, "vec.cast");
}
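
The address arithmetic above reduces to a single multiply in element units plus a pointer cast. A minimal standalone sketch of the same computation (plain C++, hypothetical helper name, not the LLVM API):

    #include <cassert>
    #include <cstddef>

    // Sketch of computeVectorAddr's element-offset math: the vector with
    // index VecIdx starts VecIdx * Stride elements past BasePtr.
    template <typename EltType>
    EltType *computeVectorAddrSketch(EltType *BasePtr, size_t VecIdx,
                                     size_t Stride, size_t NumElements) {
      assert(Stride >= NumElements && "Stride must cover a full vector");
      return BasePtr + VecIdx * Stride;
    }

    // The 2x3 sub-matrix example from the comment above (Stride = 4, 2 rows):
    // computeVectorAddrSketch(Base, 0, 4, 2) == Base
    // computeVectorAddrSketch(Base, 1, 4, 2) == Base + 4
    // computeVectorAddrSketch(Base, 2, 4, 2) == Base + 8
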
/// LowerMatrixIntrinsics contains the methods used to lower matrix intrinsics.
@@ -113,15 +160,16 @@ Value *computeColumnAddr(Value *BasePtr, Value *Col, Value *Stride,
/// Currently, the lowering for each matrix intrinsic is done as follows:
/// 1. Propagate the shape information from intrinsics to connected
/// instructions.
-/// 2. Lower instructions with shape information.
+/// 2. Lower instructions with shape information (assuming column-major layout).
+/// The lowering works similarly using row-major layout.
/// 2.1. Get column vectors for each argument. If we already lowered the
/// definition of an argument, use the produced column vectors directly.
/// If not, split the operand vector containing an embedded matrix into
///      a set of column vectors.
-/// 2.2. Lower the instruction in terms of columnwise operations, which yields
-/// a set of column vectors containing result matrix. Note that we lower
-/// all instructions that have shape information. Besides the intrinsics,
-/// this includes stores for example.
+/// 2.2. Lower the instruction in terms of column-major operations, which
+/// yields a set of column vectors containing the result matrix. Note that we
+/// lower all instructions that have shape information. Besides the
+/// intrinsics, this includes stores for example.
/// 2.3. Update uses of the lowered instruction. If we have shape information
/// for a user, there is nothing to do, as we will look up the result
/// column matrix when lowering the user. For other uses, we embed the
@@ -134,42 +182,157 @@ class LowerMatrixIntrinsics {
Function &Func;
const DataLayout &DL;
const TargetTransformInfo &TTI;
+ AliasAnalysis &AA;
+ DominatorTree &DT;
+ LoopInfo &LI;
+ OptimizationRemarkEmitter &ORE;
+
+ /// Contains estimates of the number of operations (loads, stores, compute)
+ /// required to lower a matrix operation.
+ struct OpInfoTy {
+ /// Number of stores emitted to generate this matrix.
+ unsigned NumStores = 0;
+ /// Number of loads emitted to generate this matrix.
+ unsigned NumLoads = 0;
+ /// Number of compute operations emitted to generate this matrix.
+ unsigned NumComputeOps = 0;
+
+ OpInfoTy &operator+=(const OpInfoTy &RHS) {
+ NumStores += RHS.NumStores;
+ NumLoads += RHS.NumLoads;
+ NumComputeOps += RHS.NumComputeOps;
+ return *this;
+ }
+ };
+
+ /// Wrapper class representing a matrix as a set of vectors, either in row or
+ /// column major layout. All vectors must have the same vector type.
+ class MatrixTy {
+ SmallVector<Value *, 16> Vectors;
+
+ OpInfoTy OpInfo;
- /// Wrapper class representing a matrix as a set of column vectors.
- /// All column vectors must have the same vector type.
- class ColumnMatrixTy {
- SmallVector<Value *, 16> Columns;
+ bool IsColumnMajor = true;
public:
- ColumnMatrixTy() : Columns() {}
- ColumnMatrixTy(ArrayRef<Value *> Cols)
- : Columns(Cols.begin(), Cols.end()) {}
+ MatrixTy()
+ : Vectors(),
+ IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {}
+ MatrixTy(ArrayRef<Value *> Vectors)
+ : Vectors(Vectors.begin(), Vectors.end()),
+ IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {}
+ MatrixTy(unsigned NumRows, unsigned NumColumns, Type *EltTy)
+ : IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {
+
+ unsigned D = isColumnMajor() ? NumColumns : NumRows;
+ for (unsigned J = 0; J < D; ++J)
+ addVector(UndefValue::get(FixedVectorType::get(
+ EltTy, isColumnMajor() ? NumRows : NumColumns)));
+ }
+
+ Value *getVector(unsigned i) const { return Vectors[i]; }
+ Value *getColumn(unsigned i) const {
+ assert(isColumnMajor() && "only supported for column-major matrixes");
+ return Vectors[i];
+ }
+ Value *getRow(unsigned i) const {
+ assert(!isColumnMajor() && "only supported for row-major matrixes");
+ return Vectors[i];
+ }
- Value *getColumn(unsigned i) const { return Columns[i]; }
+ void setVector(unsigned i, Value *V) { Vectors[i] = V; }
- void setColumn(unsigned i, Value *V) { Columns[i] = V; }
+ Type *getElementType() { return getVectorTy()->getElementType(); }
- size_t getNumColumns() const { return Columns.size(); }
- size_t getNumRows() const {
- assert(Columns.size() > 0 && "Cannot call getNumRows without columns");
- return cast<VectorType>(Columns[0]->getType())->getNumElements();
+ unsigned getNumVectors() const {
+ if (isColumnMajor())
+ return getNumColumns();
+ return getNumRows();
}
- const SmallVectorImpl<Value *> &getColumnVectors() const { return Columns; }
+ unsigned getNumColumns() const {
+ if (isColumnMajor())
+ return Vectors.size();
+ else {
+ assert(Vectors.size() > 0 && "Cannot call getNumColumns without vectors");
+ return cast<FixedVectorType>(Vectors[0]->getType())->getNumElements();
+ }
+ }
+ unsigned getNumRows() const {
+ if (isColumnMajor()) {
+ assert(Vectors.size() > 0 && "Cannot call getNumRows without columns");
+ return cast<FixedVectorType>(Vectors[0]->getType())->getNumElements();
+ } else
+ return Vectors.size();
+ }
- SmallVectorImpl<Value *> &getColumnVectors() { return Columns; }
+ void addVector(Value *V) { Vectors.push_back(V); }
+ VectorType *getColumnTy() {
+ assert(isColumnMajor() && "only supported for column-major matrixes");
+ return getVectorTy();
+ }
- void addColumn(Value *V) { Columns.push_back(V); }
+ VectorType *getVectorTy() {
+ return cast<VectorType>(Vectors[0]->getType());
+ }
iterator_range<SmallVector<Value *, 8>::iterator> columns() {
- return make_range(Columns.begin(), Columns.end());
+ assert(isColumnMajor() &&
+ "columns() only supported for column-major matrixes");
+ return make_range(Vectors.begin(), Vectors.end());
}
- /// Embed the columns of the matrix into a flat vector by concatenating
+ iterator_range<SmallVector<Value *, 8>::iterator> vectors() {
+ return make_range(Vectors.begin(), Vectors.end());
+ }
+
+ /// Embed the vectors of the matrix into a flat vector by concatenating
/// them.
Value *embedInVector(IRBuilder<> &Builder) const {
- return Columns.size() == 1 ? Columns[0]
- : concatenateVectors(Builder, Columns);
+ return Vectors.size() == 1 ? Vectors[0]
+ : concatenateVectors(Builder, Vectors);
+ }
+
+ MatrixTy &addNumLoads(unsigned N) {
+ OpInfo.NumLoads += N;
+ return *this;
+ }
+
+ void setNumLoads(unsigned N) { OpInfo.NumLoads = N; }
+
+ MatrixTy &addNumStores(unsigned N) {
+ OpInfo.NumStores += N;
+ return *this;
+ }
+
+ MatrixTy &addNumComputeOps(unsigned N) {
+ OpInfo.NumComputeOps += N;
+ return *this;
+ }
+
+ unsigned getNumStores() const { return OpInfo.NumStores; }
+ unsigned getNumLoads() const { return OpInfo.NumLoads; }
+ unsigned getNumComputeOps() const { return OpInfo.NumComputeOps; }
+
+ const OpInfoTy &getOpInfo() const { return OpInfo; }
+
+ bool isColumnMajor() const { return IsColumnMajor; }
+
+ unsigned getStride() const {
+ if (isColumnMajor())
+ return getNumRows();
+ return getNumColumns();
+ }
+
+ /// Extract a vector of \p NumElts starting at index (\p I, \p J). If the
+ /// matrix is column-major, the result vector is extracted from a column
+ /// vector, otherwise from a row vector.
+ Value *extractVector(unsigned I, unsigned J, unsigned NumElts,
+ IRBuilder<> &Builder) const {
+ Value *Vec = isColumnMajor() ? getColumn(J) : getRow(I);
+ Value *Undef = UndefValue::get(Vec->getType());
+ return Builder.CreateShuffleVector(
+ Vec, Undef, createSequentialMask(isColumnMajor() ? I : J, NumElts, 0),
+ "block");
}
};
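
extractVector above relies on createSequentialMask producing a run of consecutive lane indices. A sketch of the slice it takes, with the matrix modeled as plain arrays (illustration only, hypothetical names):

    #include <vector>

    // Sketch of MatrixTy::extractVector on a column-major matrix: take
    // NumElts consecutive elements starting at row I from column J, i.e. the
    // lanes selected by createSequentialMask(I, NumElts, 0).
    std::vector<float> extractVectorSketch(
        const std::vector<std::vector<float>> &Columns, unsigned I, unsigned J,
        unsigned NumElts) {
      std::vector<float> Block;
      for (unsigned K = 0; K < NumElts; ++K) // lanes I .. I + NumElts - 1
        Block.push_back(Columns[J][I + K]);
      return Block;
    }
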
@@ -177,12 +340,15 @@ class LowerMatrixIntrinsics {
unsigned NumRows;
unsigned NumColumns;
+ bool IsColumnMajor;
+
ShapeInfo(unsigned NumRows = 0, unsigned NumColumns = 0)
- : NumRows(NumRows), NumColumns(NumColumns) {}
+ : NumRows(NumRows), NumColumns(NumColumns),
+ IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {}
ShapeInfo(Value *NumRows, Value *NumColumns)
- : NumRows(cast<ConstantInt>(NumRows)->getZExtValue()),
- NumColumns(cast<ConstantInt>(NumColumns)->getZExtValue()) {}
+ : ShapeInfo(cast<ConstantInt>(NumRows)->getZExtValue(),
+ cast<ConstantInt>(NumColumns)->getZExtValue()) {}
bool operator==(const ShapeInfo &other) {
return NumRows == other.NumRows && NumColumns == other.NumColumns;
@@ -195,12 +361,24 @@ class LowerMatrixIntrinsics {
assert(NumRows == 0 || NumColumns != 0);
return NumRows != 0;
}
+
+ unsigned getStride() const {
+ if (IsColumnMajor)
+ return NumRows;
+ return NumColumns;
+ }
+
+ unsigned getNumVectors() const {
+ if (IsColumnMajor)
+ return NumColumns;
+ return NumRows;
+ }
};
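
Concretely, a 4x2 matrix splits into two 4-element vectors in column-major layout but four 2-element vectors in row-major layout; getStride() and getNumVectors() encode exactly that. A compact mirror of the two accessors (plain C++ sketch):

    struct ShapeInfoSketch {
      unsigned NumRows, NumColumns;
      bool IsColumnMajor;
      // Elements between the starts of two consecutive vectors.
      unsigned getStride() const { return IsColumnMajor ? NumRows : NumColumns; }
      // Number of vectors the matrix is split into.
      unsigned getNumVectors() const {
        return IsColumnMajor ? NumColumns : NumRows;
      }
    };
    // {4, 2, true}  -> stride 4, 2 vectors (the columns)
    // {4, 2, false} -> stride 2, 4 vectors (the rows)
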
/// Maps instructions to their shape information. The shape information
/// describes the shape to be used while lowering. This matches the shape of
/// the result value of the instruction, with the only exceptions being store
- /// instructions and the matrix_columnwise_store intrinsics. For those, the
+ /// instructions and the matrix_column_major_store intrinsics. For those, the
/// shape information indicates that those instructions should be lowered
/// using shape information as well.
DenseMap<Value *, ShapeInfo> ShapeMap;
@@ -211,31 +389,49 @@ class LowerMatrixIntrinsics {
SmallVector<Instruction *, 16> ToRemove;
/// Map from instructions to their produced column matrix.
- DenseMap<Value *, ColumnMatrixTy> Inst2ColumnMatrix;
+ MapVector<Value *, MatrixTy> Inst2ColumnMatrix;
public:
- LowerMatrixIntrinsics(Function &F, TargetTransformInfo &TTI)
- : Func(F), DL(F.getParent()->getDataLayout()), TTI(TTI) {}
+ LowerMatrixIntrinsics(Function &F, TargetTransformInfo &TTI,
+ AliasAnalysis &AA, DominatorTree &DT, LoopInfo &LI,
+ OptimizationRemarkEmitter &ORE)
+ : Func(F), DL(F.getParent()->getDataLayout()), TTI(TTI), AA(AA), DT(DT),
+ LI(LI), ORE(ORE) {}
+
+ unsigned getNumOps(Type *VT) {
+ assert(isa<VectorType>(VT) && "Expected vector type");
+ return getNumOps(VT->getScalarType(),
+ cast<FixedVectorType>(VT)->getNumElements());
+ }
- /// Return the set of column vectors that a matrix value is lowered to.
+ /// Return the estimated number of vector ops required for an operation on
+ /// \p VT * N.
+ unsigned getNumOps(Type *ST, unsigned N) {
+ return std::ceil((ST->getPrimitiveSizeInBits() * N).getFixedSize() /
+ double(TTI.getRegisterBitWidth(true)));
+ }
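
As a worked example of this estimate (register width assumed for illustration): a <4 x double> operation on a target with 128-bit vector registers counts as ceil(64 * 4 / 128) = 2 ops. The formula in standalone form:

    #include <cmath>

    // Sketch of getNumOps: vector ops needed for N elements of ScalarBits
    // each, on RegisterBits-wide vector registers.
    unsigned getNumOpsSketch(unsigned ScalarBits, unsigned N,
                             unsigned RegisterBits) {
      return std::ceil(ScalarBits * N / double(RegisterBits));
    }
    // getNumOpsSketch(64, 4, 128) == 2  // <4 x double> splits in two
    // getNumOpsSketch(32, 4, 128) == 1  // <4 x float> fits one register
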
+
+ /// Return the set of vectors that a matrix value is lowered to.
///
- /// If we lowered \p MatrixVal, just return the cache result column matrix.
- /// Otherwie split the flat vector \p MatrixVal containing a matrix with
- /// shape \p SI into column vectors.
- ColumnMatrixTy getMatrix(Value *MatrixVal, const ShapeInfo &SI,
- IRBuilder<> Builder) {
+ /// If we lowered \p MatrixVal, just return the cached result matrix. Otherwise
+ /// split the flat vector \p MatrixVal containing a matrix with shape \p SI
+ /// into vectors.
+ MatrixTy getMatrix(Value *MatrixVal, const ShapeInfo &SI,
+ IRBuilder<> &Builder) {
VectorType *VType = dyn_cast<VectorType>(MatrixVal->getType());
assert(VType && "MatrixVal must be a vector type");
- assert(VType->getNumElements() == SI.NumRows * SI.NumColumns &&
+ assert(cast<FixedVectorType>(VType)->getNumElements() ==
+ SI.NumRows * SI.NumColumns &&
"The vector size must match the number of matrix elements");
// Check if we lowered MatrixVal using shape information. In that case,
- // return the existing column matrix, if it matches the requested shape
+ // return the existing matrix, if it matches the requested shape
// information. If there is a mis-match, embed the result in a flat
// vector and split it later.
auto Found = Inst2ColumnMatrix.find(MatrixVal);
if (Found != Inst2ColumnMatrix.end()) {
- ColumnMatrixTy &M = Found->second;
+ MatrixTy &M = Found->second;
// Return the found matrix, if its shape matches the requested shape
// information
if (SI.NumRows == M.getNumRows() && SI.NumColumns == M.getNumColumns())
@@ -247,10 +443,12 @@ public:
// Otherwise split MatrixVal.
SmallVector<Value *, 16> SplitVecs;
Value *Undef = UndefValue::get(VType);
- for (unsigned MaskStart = 0; MaskStart < VType->getNumElements();
- MaskStart += SI.NumRows) {
- Constant *Mask = createSequentialMask(Builder, MaskStart, SI.NumRows, 0);
- Value *V = Builder.CreateShuffleVector(MatrixVal, Undef, Mask, "split");
+ for (unsigned MaskStart = 0;
+ MaskStart < cast<FixedVectorType>(VType)->getNumElements();
+ MaskStart += SI.getStride()) {
+ Value *V = Builder.CreateShuffleVector(
+ MatrixVal, Undef, createSequentialMask(MaskStart, SI.getStride(), 0),
+ "split");
SplitVecs.push_back(V);
}
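
For example, a 4x2 column-major matrix embedded in an 8-element flat vector is split with the masks <0,1,2,3> and <4,5,6,7>. The loop above, restated over plain arrays (sketch, illustration only):

    #include <cstddef>
    #include <vector>

    // Sketch of getMatrix's split loop: cut a flat embedded matrix into
    // stride-sized vectors, one per shufflevector above.
    std::vector<std::vector<float>> splitMatrixSketch(
        const std::vector<float> &Flat, size_t Stride) {
      std::vector<std::vector<float>> Vecs;
      for (size_t MaskStart = 0; MaskStart < Flat.size(); MaskStart += Stride)
        Vecs.emplace_back(Flat.begin() + MaskStart,
                          Flat.begin() + MaskStart + Stride);
      return Vecs;
    }
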
@@ -308,8 +506,8 @@ public:
switch (II->getIntrinsicID()) {
case Intrinsic::matrix_multiply:
case Intrinsic::matrix_transpose:
- case Intrinsic::matrix_columnwise_load:
- case Intrinsic::matrix_columnwise_store:
+ case Intrinsic::matrix_column_major_load:
+ case Intrinsic::matrix_column_major_store:
return true;
default:
return false;
@@ -348,13 +546,13 @@ public:
m_Value(MatrixA), m_Value(M), m_Value(N)))) {
// Flip dimensions.
Propagate = setShapeInfo(Inst, {N, M});
- } else if (match(Inst, m_Intrinsic<Intrinsic::matrix_columnwise_store>(
+ } else if (match(Inst, m_Intrinsic<Intrinsic::matrix_column_major_store>(
m_Value(MatrixA), m_Value(), m_Value(),
- m_Value(M), m_Value(N)))) {
+ m_Value(), m_Value(M), m_Value(N)))) {
Propagate = setShapeInfo(Inst, {N, M});
- } else if (match(Inst,
- m_Intrinsic<Intrinsic::matrix_columnwise_load>(
- m_Value(), m_Value(), m_Value(M), m_Value(N)))) {
+ } else if (match(Inst, m_Intrinsic<Intrinsic::matrix_column_major_load>(
+ m_Value(), m_Value(), m_Value(), m_Value(M),
+ m_Value(N)))) {
Propagate = setShapeInfo(Inst, {M, N});
} else if (match(Inst, m_Store(m_Value(MatrixA), m_Value()))) {
auto OpShape = ShapeMap.find(MatrixA);
@@ -426,14 +624,14 @@ public:
// Flip dimensions.
if (setShapeInfo(MatrixA, {M, N}))
pushInstruction(MatrixA, WorkList);
- } else if (match(V, m_Intrinsic<Intrinsic::matrix_columnwise_store>(
- m_Value(MatrixA), m_Value(), m_Value(),
+ } else if (match(V, m_Intrinsic<Intrinsic::matrix_column_major_store>(
+ m_Value(MatrixA), m_Value(), m_Value(), m_Value(),
m_Value(M), m_Value(N)))) {
if (setShapeInfo(MatrixA, {M, N})) {
pushInstruction(MatrixA, WorkList);
}
} else if (isa<LoadInst>(V) ||
- match(V, m_Intrinsic<Intrinsic::matrix_columnwise_load>())) {
+ match(V, m_Intrinsic<Intrinsic::matrix_column_major_load>())) {
// Nothing to do, no matrix input.
} else if (isa<StoreInst>(V)) {
// Nothing to do. We forward-propagated to this so we would just
@@ -472,8 +670,8 @@ public:
switch (II->getIntrinsicID()) {
case Intrinsic::matrix_multiply:
case Intrinsic::matrix_transpose:
- case Intrinsic::matrix_columnwise_load:
- case Intrinsic::matrix_columnwise_store:
+ case Intrinsic::matrix_column_major_load:
+ case Intrinsic::matrix_column_major_store:
WorkList.push_back(&Inst);
break;
default:
@@ -487,45 +685,57 @@ public:
}
}
- ReversePostOrderTraversal<Function *> RPOT(&Func);
bool Changed = false;
- for (auto *BB : RPOT) {
- for (Instruction &Inst : make_early_inc_range(*BB)) {
- IRBuilder<> Builder(&Inst);
-
- if (CallInst *CInst = dyn_cast<CallInst>(&Inst))
- Changed |= VisitCallInst(CInst);
-
- Value *Op1;
- Value *Op2;
- if (auto *BinOp = dyn_cast<BinaryOperator>(&Inst))
- Changed |= VisitBinaryOperator(BinOp);
- if (match(&Inst, m_Load(m_Value(Op1))))
- Changed |= VisitLoad(&Inst, Op1, Builder);
- else if (match(&Inst, m_Store(m_Value(Op1), m_Value(Op2))))
- Changed |= VisitStore(&Inst, Op1, Op2, Builder);
+ SmallVector<CallInst *, 16> MaybeFusableInsts;
+ SmallVector<Instruction *, 16> MatrixInsts;
+
+ // First, collect all instructions with shape information and candidates for
+ // fusion (currently only matrix multiplies).
+ ReversePostOrderTraversal<Function *> RPOT(&Func);
+ for (auto *BB : RPOT)
+ for (Instruction &I : *BB) {
+ if (ShapeMap.find(&I) == ShapeMap.end())
+ continue;
+ if (match(&I, m_Intrinsic<Intrinsic::matrix_multiply>()))
+ MaybeFusableInsts.push_back(cast<CallInst>(&I));
+ MatrixInsts.push_back(&I);
}
+
+ // Second, try to fuse candidates.
+ SmallPtrSet<Instruction *, 16> FusedInsts;
+ for (CallInst *CI : MaybeFusableInsts)
+ LowerMatrixMultiplyFused(CI, FusedInsts);
+ Changed = !FusedInsts.empty();
+
+ // Third, lower remaining instructions with shape information.
+ for (Instruction *Inst : MatrixInsts) {
+ if (FusedInsts.count(Inst))
+ continue;
+
+ IRBuilder<> Builder(Inst);
+
+ if (CallInst *CInst = dyn_cast<CallInst>(Inst))
+ Changed |= VisitCallInst(CInst);
+
+ Value *Op1;
+ Value *Op2;
+ if (auto *BinOp = dyn_cast<BinaryOperator>(Inst))
+ Changed |= VisitBinaryOperator(BinOp);
+ if (match(Inst, m_Load(m_Value(Op1))))
+ Changed |= VisitLoad(cast<LoadInst>(Inst), Op1, Builder);
+ else if (match(Inst, m_Store(m_Value(Op1), m_Value(Op2))))
+ Changed |= VisitStore(cast<StoreInst>(Inst), Op1, Op2, Builder);
}
+ RemarkGenerator RemarkGen(Inst2ColumnMatrix, ORE, Func);
+ RemarkGen.emitRemarks();
+
for (Instruction *Inst : reverse(ToRemove))
Inst->eraseFromParent();
return Changed;
}
- LoadInst *createColumnLoad(Value *ColumnPtr, Type *EltType,
- IRBuilder<> Builder) {
- unsigned Align = DL.getABITypeAlignment(EltType);
- return Builder.CreateAlignedLoad(ColumnPtr, Align, "col.load");
- }
-
- StoreInst *createColumnStore(Value *ColumnValue, Value *ColumnPtr,
- Type *EltType, IRBuilder<> Builder) {
- unsigned Align = DL.getABITypeAlignment(EltType);
- return Builder.CreateAlignedStore(ColumnValue, ColumnPtr, Align);
- }
-
-
/// Turns \p BasePtr into an elementwise pointer to \p EltType.
Value *createElementPtr(Value *BasePtr, Type *EltType, IRBuilder<> &Builder) {
unsigned AS = cast<PointerType>(BasePtr->getType())->getAddressSpace();
@@ -545,11 +755,11 @@ public:
case Intrinsic::matrix_transpose:
LowerTranspose(Inst);
break;
- case Intrinsic::matrix_columnwise_load:
- LowerColumnwiseLoad(Inst);
+ case Intrinsic::matrix_column_major_load:
+ LowerColumnMajorLoad(Inst);
break;
- case Intrinsic::matrix_columnwise_store:
- LowerColumnwiseStore(Inst);
+ case Intrinsic::matrix_column_major_store:
+ LowerColumnMajorStore(Inst);
break;
default:
return false;
@@ -557,108 +767,200 @@ public:
return true;
}
- void LowerLoad(Instruction *Inst, Value *Ptr, Value *Stride,
- ShapeInfo Shape) {
- IRBuilder<> Builder(Inst);
- auto VType = cast<VectorType>(Inst->getType());
+ /// Compute the alignment for a column/row \p Idx with \p Stride between them.
+ /// The address at \p Idx == 0 has alignment \p A. If \p Stride is a
+ /// ConstantInt, reduce the initial alignment based on the byte offset. For
+ /// non-ConstantInt strides, return the common alignment of the initial
+ /// alignment and the element size in bytes.
+ Align getAlignForIndex(unsigned Idx, Value *Stride, Type *ElementTy,
+ MaybeAlign A) const {
+ Align InitialAlign = DL.getValueOrABITypeAlignment(A, ElementTy);
+ if (Idx == 0)
+ return InitialAlign;
+
+ TypeSize ElementSizeInBits = DL.getTypeSizeInBits(ElementTy);
+ if (auto *ConstStride = dyn_cast<ConstantInt>(Stride)) {
+ uint64_t StrideInBytes =
+ ConstStride->getZExtValue() * ElementSizeInBits / 8;
+ return commonAlignment(InitialAlign, Idx * StrideInBytes);
+ }
+ return commonAlignment(InitialAlign, ElementSizeInBits / 8);
+ }
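
A worked example (alignment values assumed): with an initial alignment of 16, double elements, and a constant stride of 3 elements, vector 1 starts at byte offset 1 * 3 * 8 = 24, so its alignment drops to commonAlignment(16, 24) = 8, while vector 2 (offset 48) keeps 16. A sketch of the common-alignment computation itself:

    #include <algorithm>
    #include <cstdint>

    // Sketch of llvm::commonAlignment: the largest power of two dividing both
    // the initial alignment and the byte offset. Offset must be non-zero (the
    // Idx == 0 case returns early above).
    uint64_t commonAlignmentSketch(uint64_t Align, uint64_t Offset) {
      return std::min(Align, Offset & -Offset); // low set bit = 2-power divisor
    }
    // commonAlignmentSketch(16, 24) == 8
    // commonAlignmentSketch(16, 48) == 16
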
+
+ /// Load a matrix with \p Shape starting at \p Ptr and using \p Stride between
+ /// vectors.
+ MatrixTy loadMatrix(Type *Ty, Value *Ptr, MaybeAlign MAlign, Value *Stride,
+ bool IsVolatile, ShapeInfo Shape, IRBuilder<> &Builder) {
+ auto VType = cast<VectorType>(Ty);
Value *EltPtr = createElementPtr(Ptr, VType->getElementType(), Builder);
- ColumnMatrixTy Result;
- // Distance between start of one column and the start of the next
- for (unsigned C = 0, E = Shape.NumColumns; C < E; ++C) {
- Value *GEP =
- computeColumnAddr(EltPtr, Builder.getInt32(C), Stride, Shape.NumRows,
- VType->getElementType(), Builder);
- Value *Column = createColumnLoad(GEP, VType->getElementType(), Builder);
- Result.addColumn(Column);
+ MatrixTy Result;
+ for (unsigned I = 0, E = Shape.getNumVectors(); I < E; ++I) {
+ Value *GEP = computeVectorAddr(EltPtr, Builder.getInt64(I), Stride,
+ Shape.getStride(), VType->getElementType(),
+ Builder);
+ Value *Vector = Builder.CreateAlignedLoad(
+ GEP, getAlignForIndex(I, Stride, VType->getElementType(), MAlign),
+ IsVolatile, "col.load");
+
+ Result.addVector(Vector);
}
+ return Result.addNumLoads(getNumOps(Result.getVectorTy()) *
+ Result.getNumVectors());
+ }
- finalizeLowering(Inst, Result, Builder);
+ /// Loads a sub-matrix with shape \p ResultShape from a \p R x \p C matrix,
+ /// starting at \p MatrixPtr[I][J].
+ MatrixTy loadMatrix(Value *MatrixPtr, MaybeAlign Align, bool IsVolatile,
+ ShapeInfo MatrixShape, Value *I, Value *J,
+ ShapeInfo ResultShape, Type *EltTy,
+ IRBuilder<> &Builder) {
+
+ Value *Offset = Builder.CreateAdd(
+ Builder.CreateMul(J, Builder.getInt64(MatrixShape.getStride())), I);
+
+ unsigned AS = cast<PointerType>(MatrixPtr->getType())->getAddressSpace();
+ Value *EltPtr =
+ Builder.CreatePointerCast(MatrixPtr, PointerType::get(EltTy, AS));
+ Value *TileStart = Builder.CreateGEP(EltTy, EltPtr, Offset);
+ auto *TileTy = FixedVectorType::get(EltTy, ResultShape.NumRows *
+ ResultShape.NumColumns);
+ Type *TilePtrTy = PointerType::get(TileTy, AS);
+ Value *TilePtr =
+ Builder.CreatePointerCast(TileStart, TilePtrTy, "col.cast");
+
+ return loadMatrix(TileTy, TilePtr, Align,
+ Builder.getInt64(MatrixShape.getStride()), IsVolatile,
+ ResultShape, Builder);
+ }
+
+ /// Lower a load instruction with shape information.
+ void LowerLoad(Instruction *Inst, Value *Ptr, MaybeAlign Align, Value *Stride,
+ bool IsVolatile, ShapeInfo Shape) {
+ IRBuilder<> Builder(Inst);
+ finalizeLowering(Inst,
+ loadMatrix(Inst->getType(), Ptr, Align, Stride, IsVolatile,
+ Shape, Builder),
+ Builder);
}
- /// Lowers llvm.matrix.columnwise.load.
+ /// Lowers llvm.matrix.column.major.load.
///
/// The intrinsic loads a matrix from memory using a stride between columns.
- void LowerColumnwiseLoad(CallInst *Inst) {
+ void LowerColumnMajorLoad(CallInst *Inst) {
+ assert(MatrixLayout == MatrixLayoutTy::ColumnMajor &&
+ "Intrinsic only supports column-major layout!");
Value *Ptr = Inst->getArgOperand(0);
Value *Stride = Inst->getArgOperand(1);
- LowerLoad(Inst, Ptr, Stride,
- {Inst->getArgOperand(2), Inst->getArgOperand(3)});
+ LowerLoad(Inst, Ptr, Inst->getParamAlign(0), Stride,
+ cast<ConstantInt>(Inst->getArgOperand(2))->isOne(),
+ {Inst->getArgOperand(3), Inst->getArgOperand(4)});
}
- void LowerStore(Instruction *Inst, Value *Matrix, Value *Ptr, Value *Stride,
- ShapeInfo Shape) {
- IRBuilder<> Builder(Inst);
- auto VType = cast<VectorType>(Matrix->getType());
+ /// Stores a sub-matrix \p StoreVal into the \p R x \p C matrix starting at \p
+ /// MatrixPtr[I][J].
+ void storeMatrix(const MatrixTy &StoreVal, Value *MatrixPtr,
+ MaybeAlign MAlign, bool IsVolatile, ShapeInfo MatrixShape,
+ Value *I, Value *J, Type *EltTy, IRBuilder<> &Builder) {
+ Value *Offset = Builder.CreateAdd(
+ Builder.CreateMul(J, Builder.getInt64(MatrixShape.getStride())), I);
+
+ unsigned AS = cast<PointerType>(MatrixPtr->getType())->getAddressSpace();
+ Value *EltPtr =
+ Builder.CreatePointerCast(MatrixPtr, PointerType::get(EltTy, AS));
+ Value *TileStart = Builder.CreateGEP(EltTy, EltPtr, Offset);
+ auto *TileTy = FixedVectorType::get(EltTy, StoreVal.getNumRows() *
+ StoreVal.getNumColumns());
+ Type *TilePtrTy = PointerType::get(TileTy, AS);
+ Value *TilePtr =
+ Builder.CreatePointerCast(TileStart, TilePtrTy, "col.cast");
+
+ storeMatrix(TileTy, StoreVal, TilePtr, MAlign,
+ Builder.getInt64(MatrixShape.getStride()), IsVolatile, Builder);
+ }
+
+ /// Store matrix \p StoreVal starting at \p Ptr and using \p Stride between
+ /// vectors.
+ MatrixTy storeMatrix(Type *Ty, MatrixTy StoreVal, Value *Ptr,
+ MaybeAlign MAlign, Value *Stride, bool IsVolatile,
+ IRBuilder<> &Builder) {
+ auto VType = cast<VectorType>(Ty);
Value *EltPtr = createElementPtr(Ptr, VType->getElementType(), Builder);
- auto LM = getMatrix(Matrix, Shape, Builder);
- for (auto C : enumerate(LM.columns())) {
- Value *GEP =
- computeColumnAddr(EltPtr, Builder.getInt32(C.index()), Stride,
- Shape.NumRows, VType->getElementType(), Builder);
- createColumnStore(C.value(), GEP, VType->getElementType(), Builder);
+ for (auto Vec : enumerate(StoreVal.vectors())) {
+ Value *GEP = computeVectorAddr(EltPtr, Builder.getInt64(Vec.index()),
+ Stride, StoreVal.getStride(),
+ VType->getElementType(), Builder);
+ Builder.CreateAlignedStore(Vec.value(), GEP,
+ getAlignForIndex(Vec.index(), Stride,
+ VType->getElementType(),
+ MAlign),
+ IsVolatile);
}
+ return MatrixTy().addNumStores(getNumOps(StoreVal.getVectorTy()) *
+ StoreVal.getNumVectors());
+ }
- ToRemove.push_back(Inst);
+ /// Lower a store instruction with shape information.
+ void LowerStore(Instruction *Inst, Value *Matrix, Value *Ptr, MaybeAlign A,
+ Value *Stride, bool IsVolatile, ShapeInfo Shape) {
+ IRBuilder<> Builder(Inst);
+ auto StoreVal = getMatrix(Matrix, Shape, Builder);
+ finalizeLowering(Inst,
+ storeMatrix(Matrix->getType(), StoreVal, Ptr, A, Stride,
+ IsVolatile, Builder),
+ Builder);
}
- /// Lowers llvm.matrix.columnwise.store.
+ /// Lowers llvm.matrix.column.major.store.
///
/// The intrinsic stores a matrix back to memory using a stride between columns.
- void LowerColumnwiseStore(CallInst *Inst) {
+ void LowerColumnMajorStore(CallInst *Inst) {
+ assert(MatrixLayout == MatrixLayoutTy::ColumnMajor &&
+ "Intrinsic only supports column-major layout!");
Value *Matrix = Inst->getArgOperand(0);
Value *Ptr = Inst->getArgOperand(1);
Value *Stride = Inst->getArgOperand(2);
- LowerStore(Inst, Matrix, Ptr, Stride,
- {Inst->getArgOperand(3), Inst->getArgOperand(4)});
- }
-
- /// Extract a column vector of \p NumElts starting at index (\p I, \p J) from
- /// the matrix \p LM represented as a vector of column vectors.
- Value *extractVector(const ColumnMatrixTy &LM, unsigned I, unsigned J,
- unsigned NumElts, IRBuilder<> Builder) {
- Value *Col = LM.getColumn(J);
- Value *Undef = UndefValue::get(Col->getType());
- Constant *Mask = createSequentialMask(Builder, I, NumElts, 0);
- return Builder.CreateShuffleVector(Col, Undef, Mask, "block");
+ LowerStore(Inst, Matrix, Ptr, Inst->getParamAlign(1), Stride,
+ cast<ConstantInt>(Inst->getArgOperand(3))->isOne(),
+ {Inst->getArgOperand(4), Inst->getArgOperand(5)});
}
// Set elements I..I+NumElts-1 to Block
Value *insertVector(Value *Col, unsigned I, Value *Block,
- IRBuilder<> Builder) {
+ IRBuilder<> &Builder) {
// First, bring Block to the same size as Col
unsigned BlockNumElts =
- cast<VectorType>(Block->getType())->getNumElements();
- unsigned NumElts = cast<VectorType>(Col->getType())->getNumElements();
+ cast<FixedVectorType>(Block->getType())->getNumElements();
+ unsigned NumElts = cast<FixedVectorType>(Col->getType())->getNumElements();
assert(NumElts >= BlockNumElts && "Too few elements for current block");
- Value *ExtendMask =
- createSequentialMask(Builder, 0, BlockNumElts, NumElts - BlockNumElts);
Value *Undef = UndefValue::get(Block->getType());
- Block = Builder.CreateShuffleVector(Block, Undef, ExtendMask);
+ Block = Builder.CreateShuffleVector(
+ Block, Undef,
+ createSequentialMask(0, BlockNumElts, NumElts - BlockNumElts));
// If Col is 7 long and I is 2 and BlockNumElts is 2 the mask is: 0, 1, 7,
// 8, 4, 5, 6
- SmallVector<Constant *, 16> Mask;
+ SmallVector<int, 16> Mask;
unsigned i;
for (i = 0; i < I; i++)
- Mask.push_back(Builder.getInt32(i));
+ Mask.push_back(i);
- unsigned VecNumElts = cast<VectorType>(Col->getType())->getNumElements();
+ unsigned VecNumElts =
+ cast<FixedVectorType>(Col->getType())->getNumElements();
for (; i < I + BlockNumElts; i++)
- Mask.push_back(Builder.getInt32(i - I + VecNumElts));
+ Mask.push_back(i - I + VecNumElts);
for (; i < VecNumElts; i++)
- Mask.push_back(Builder.getInt32(i));
-
- Value *MaskVal = ConstantVector::get(Mask);
+ Mask.push_back(i);
- return Builder.CreateShuffleVector(Col, Block, MaskVal);
+ return Builder.CreateShuffleVector(Col, Block, Mask);
}
Value *createMulAdd(Value *Sum, Value *A, Value *B, bool UseFPOp,
- IRBuilder<> &Builder, bool AllowContraction) {
-
+ IRBuilder<> &Builder, bool AllowContraction,
+ unsigned &NumComputeOps) {
+ NumComputeOps += getNumOps(A->getType());
if (!Sum)
return UseFPOp ? Builder.CreateFMul(A, B) : Builder.CreateMul(A, B);
@@ -666,14 +968,16 @@ public:
if (AllowContraction) {
// Use fmuladd for floating point operations and let the backend decide
// if that's profitable.
- Value *FMulAdd = Intrinsic::getDeclaration(
+ Function *FMulAdd = Intrinsic::getDeclaration(
Func.getParent(), Intrinsic::fmuladd, A->getType());
return Builder.CreateCall(FMulAdd, {A, B, Sum});
}
+ NumComputeOps += getNumOps(A->getType());
Value *Mul = Builder.CreateFMul(A, B);
return Builder.CreateFAdd(Sum, Mul);
}
+ NumComputeOps += getNumOps(A->getType());
Value *Mul = Builder.CreateMul(A, B);
return Builder.CreateAdd(Sum, Mul);
}
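
Per K step, createMulAdd thus emits one of three patterns. A scalar sketch of the floating-point selection logic (plain C++; std::fma stands in for llvm.fmuladd's single-rounding semantics):

    #include <cmath>

    // Scalar sketch of createMulAdd's cases for floating point.
    float createMulAddSketch(bool HasSum, float Sum, float A, float B,
                             bool AllowContraction) {
      if (!HasSum)
        return A * B;               // K == 0: no accumulator yet
      if (AllowContraction)
        return std::fma(A, B, Sum); // fused multiply-add, one rounding
      return Sum + A * B;           // separate mul and add, two roundings
    }
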
@@ -683,7 +987,7 @@ public:
/// cached value when they are lowered. For other users, \p Matrix is
/// flattened and the uses are updated to use it. Also marks \p Inst for
/// deletion.
- void finalizeLowering(Instruction *Inst, ColumnMatrixTy Matrix,
+ void finalizeLowering(Instruction *Inst, MatrixTy Matrix,
IRBuilder<> &Builder) {
Inst2ColumnMatrix.insert(std::make_pair(Inst, Matrix));
@@ -699,6 +1003,294 @@ public:
}
}
+ /// Compute \p Result += \p A * \p B for input matrices with left-associating
+ /// addition.
+ void emitMatrixMultiply(MatrixTy &Result, const MatrixTy &A,
+ const MatrixTy &B, bool AllowContraction,
+ IRBuilder<> &Builder, bool isTiled) {
+ const unsigned VF = std::max<unsigned>(
+ TTI.getRegisterBitWidth(true) /
+ Result.getElementType()->getPrimitiveSizeInBits().getFixedSize(),
+ 1U);
+ unsigned R = Result.getNumRows();
+ unsigned C = Result.getNumColumns();
+ unsigned M = A.getNumColumns();
+
+ bool IsFP = Result.getElementType()->isFloatingPointTy();
+ assert(A.isColumnMajor() == B.isColumnMajor() &&
+ Result.isColumnMajor() == A.isColumnMajor() &&
+ "operands must agree on matrix layout");
+ unsigned NumComputeOps = 0;
+ if (A.isColumnMajor()) {
+ // Multiply columns from the first operand with scalars from the second
+ // operand. Then move along the K axes and accumulate the columns. With
+ // this the adds can be vectorized without reassociation.
+ for (unsigned J = 0; J < C; ++J) {
+ unsigned BlockSize = VF;
+ // If Result is zero, we don't need to accumulate in the K==0 iteration.
+ bool isSumZero = isa<ConstantAggregateZero>(Result.getColumn(J));
+
+ for (unsigned I = 0; I < R; I += BlockSize) {
+ // Gradually lower the vectorization factor to cover the remainder.
+ while (I + BlockSize > R)
+ BlockSize /= 2;
+
+ Value *Sum = isTiled ? Result.extractVector(I, J, BlockSize, Builder)
+ : nullptr;
+ for (unsigned K = 0; K < M; ++K) {
+ Value *L = A.extractVector(I, K, BlockSize, Builder);
+ Value *RH = Builder.CreateExtractElement(B.getColumn(J), K);
+ Value *Splat = Builder.CreateVectorSplat(BlockSize, RH, "splat");
+ Sum = createMulAdd(isSumZero && K == 0 ? nullptr : Sum, L, Splat,
+ Result.getElementType()->isFloatingPointTy(),
+ Builder, AllowContraction, NumComputeOps);
+ }
+ Result.setVector(J,
+ insertVector(Result.getVector(J), I, Sum, Builder));
+ }
+ }
+ } else {
+ // Multiply rows from the second operand with scalars from the first
+ // operand. Then move along the K axes and accumulate the rows. With this
+ // the adds can be vectorized without reassociation.
+ for (unsigned I = 0; I < R; ++I) {
+ unsigned BlockSize = VF;
+ bool isSumZero = isa<ConstantAggregateZero>(Result.getRow(I));
+ for (unsigned J = 0; J < C; J += BlockSize) {
+ // Gradually lower the vectorization factor to cover the remainder.
+ while (J + BlockSize > C)
+ BlockSize /= 2;
+
+ Value *Sum = nullptr;
+ for (unsigned K = 0; K < M; ++K) {
+ Value *R = B.extractVector(K, J, BlockSize, Builder);
+ Value *LH = Builder.CreateExtractElement(A.getVector(I), K);
+ Value *Splat = Builder.CreateVectorSplat(BlockSize, LH, "splat");
+ Sum = createMulAdd(isSumZero && K == 0 ? nullptr : Sum, Splat, R,
+ IsFP, Builder, AllowContraction, NumComputeOps);
+ }
+ Result.setVector(I,
+ insertVector(Result.getVector(I), J, Sum, Builder));
+ }
+ }
+ }
+ Result.addNumComputeOps(NumComputeOps);
+ }
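
Stripped of vectorization, the column-major path computes Result(:,J) += A(:,K) * B(K,J): whole columns scaled by one scalar of B, accumulated along K. A scalar reference version over column-major arrays (sketch, illustration only):

    // Scalar reference for the column-major loop nest in emitMatrixMultiply.
    // The I loop is the one that vectorizes; the K reduction stays in order,
    // so the adds need no reassociation.
    void matmulColumnMajorSketch(float *Result, const float *A, const float *B,
                                 unsigned R, unsigned M, unsigned C) {
      for (unsigned J = 0; J < C; ++J)
        for (unsigned I = 0; I < R; ++I) {
          float Sum = 0;
          for (unsigned K = 0; K < M; ++K)
            Sum += A[K * R + I] * B[J * M + K]; // column-major indexing
          Result[J * R + I] = Sum;
        }
    }
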
+
+ /// Ensure that the memory in \p Load does not alias \p Store by potentially
+ /// copying it to a new location. Returns the new location, or the original
+ /// one if no copy was needed.
+ Value *getNonAliasingPointer(LoadInst *Load, StoreInst *Store,
+ CallInst *MatMul) {
+ MemoryLocation StoreLoc = MemoryLocation::get(Store);
+ MemoryLocation LoadLoc = MemoryLocation::get(Load);
+
+ AliasResult LdAliased = AA.alias(LoadLoc, StoreLoc);
+
+ // If we can statically determine noalias we're good.
+ if (!LdAliased)
+ return Load->getPointerOperand();
+
+ // Create code to check if the memory locations of the Load and Store
+ // overlap and if they do, copy Load's operand to a new buffer.
+
+ // First, create new blocks for the two parts of the check and the copy.
+ BasicBlock *Check0 = MatMul->getParent();
+ // FIXME: Use lazy DTU and update SplitBlock to accept a DTU instead of a
+ // DT. Manually collect dominator tree updates, to avoid unnecessary work,
+ // as we adjust Check0 and Check1's branches.
+ SmallVector<DominatorTree::UpdateType, 4> DTUpdates;
+ for (BasicBlock *Succ : successors(Check0))
+ DTUpdates.push_back({DT.Delete, Check0, Succ});
+
+ BasicBlock *Check1 = SplitBlock(MatMul->getParent(), MatMul, nullptr, &LI,
+ nullptr, "alias_cont");
+ BasicBlock *Copy =
+ SplitBlock(MatMul->getParent(), MatMul, nullptr, &LI, nullptr, "copy");
+ BasicBlock *Fusion = SplitBlock(MatMul->getParent(), MatMul, nullptr, &LI,
+ nullptr, "no_alias");
+
+ // Check if the loaded memory location begins before the end of the store
+ // location. If the condition holds, they might overlap, otherwise they are
+ // guaranteed to not overlap.
+ IRBuilder<> Builder(MatMul);
+ Check0->getTerminator()->eraseFromParent();
+ Builder.SetInsertPoint(Check0);
+ Type *IntPtrTy = Builder.getIntPtrTy(Load->getModule()->getDataLayout());
+ Value *StoreBegin = Builder.CreatePtrToInt(
+ const_cast<Value *>(StoreLoc.Ptr), IntPtrTy, "store.begin");
+ Value *StoreEnd = Builder.CreateAdd(
+ StoreBegin, ConstantInt::get(IntPtrTy, StoreLoc.Size.getValue()),
+ "store.end", true, true);
+ Value *LoadBegin = Builder.CreatePtrToInt(const_cast<Value *>(LoadLoc.Ptr),
+ IntPtrTy, "load.begin");
+ Builder.CreateCondBr(Builder.CreateICmpULT(LoadBegin, StoreEnd), Check1,
+ Fusion);
+
+ // Check if the store begins before the end of the load location. If the
+ // condition holds, they alias, otherwise they are guaranteed to not
+ // overlap.
+ Check1->getTerminator()->eraseFromParent();
+ Builder.SetInsertPoint(Check1, Check1->begin());
+ Value *LoadEnd = Builder.CreateAdd(
+ LoadBegin, ConstantInt::get(IntPtrTy, LoadLoc.Size.getValue()),
+ "load.end", true, true);
+ Builder.CreateCondBr(Builder.CreateICmpULT(StoreBegin, LoadEnd), Copy,
+ Fusion);
+
+ // Copy load operand to new alloca.
+ Builder.SetInsertPoint(Copy, Copy->begin());
+ AllocaInst *NewLd =
+ Builder.CreateAlloca(Load->getType(), Load->getPointerAddressSpace());
+ Builder.CreateMemCpy(NewLd, NewLd->getAlign(),
+ Load->getPointerOperand(), Load->getAlign(),
+ LoadLoc.Size.getValue());
+ Builder.SetInsertPoint(Fusion, Fusion->begin());
+ PHINode *PHI = Builder.CreatePHI(Load->getPointerOperandType(), 3);
+ PHI->addIncoming(Load->getPointerOperand(), Check0);
+ PHI->addIncoming(Load->getPointerOperand(), Check1);
+ PHI->addIncoming(NewLd, Copy);
+
+ // Adjust DT.
+ DTUpdates.push_back({DT.Insert, Check0, Check1});
+ DTUpdates.push_back({DT.Insert, Check0, Fusion});
+ DTUpdates.push_back({DT.Insert, Check1, Copy});
+ DTUpdates.push_back({DT.Insert, Check1, Fusion});
+ DT.applyUpdates(DTUpdates);
+ return PHI;
+ }
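
The two emitted branches implement the usual half-open interval overlap test: the locations can alias only if each begins before the other ends. The predicate in isolation (plain C++ sketch):

    #include <cstdint>

    // Sketch of the runtime aliasing test built above: two [begin, end) byte
    // ranges overlap iff each starts before the other ends.
    bool mayOverlapSketch(uintptr_t LoadBegin, uintptr_t LoadSize,
                          uintptr_t StoreBegin, uintptr_t StoreSize) {
      return LoadBegin < StoreBegin + StoreSize &&
             StoreBegin < LoadBegin + LoadSize;
    }
    // Only when this holds does the lowered code memcpy the load's memory to
    // a fresh alloca before running the fused tiles.
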
+
+ bool isFusionProfitable(CallInst *MatMul) {
+ if (ForceFusion)
+ return true;
+
+ ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
+ ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));
+
+ const unsigned R = LShape.NumRows;
+ const unsigned C = RShape.NumColumns;
+ const unsigned M = LShape.NumColumns;
+ auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();
+
+ const unsigned VF =
+ std::max<unsigned>(TTI.getRegisterBitWidth(true) /
+ EltType->getPrimitiveSizeInBits().getFixedSize(),
+ 1U);
+
+ // Cost model for tiling
+ //
+ // For tiling to be beneficial, we need reuse either along the R or
+ // the C axis. We vectorize along the R axis so that means at least
+ // 3 elements.
+ // TODO: Also consider cost of copying if operands alias.
+ if (R <= VF && C == 1)
+ return false;
+ // Then we need enough elements to exceed the number of vector
+ // registers we have. Note that this is an oversimplification since
+ // fusing also takes some extra loads which may exceed the number of
+ // reloads necessary.
+ unsigned Op0Regs = (R + VF - 1) / VF * M;
+ unsigned Op1Regs = (M + VF - 1) / VF * C;
+ return Op0Regs + Op1Regs > TTI.getNumberOfRegisters(true);
+ }
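
Plugging in numbers (target parameters assumed for illustration): a 16x16 double multiply on a target with 128-bit registers (VF = 2) and 32 vector registers needs Op0Regs = Op1Regs = ceil(16/2) * 16 = 128 registers per operand, so 256 > 32 and tiling is deemed profitable. The heuristic in standalone form:

    // Sketch of the fusion cost model above, with the target's vector
    // register count passed in explicitly.
    bool isFusionProfitableSketch(unsigned R, unsigned M, unsigned C,
                                  unsigned VF, unsigned NumVectorRegs) {
      if (R <= VF && C == 1)
        return false; // no reuse along R or C, tiling cannot help
      unsigned Op0Regs = (R + VF - 1) / VF * M; // ceil(R/VF) regs per column
      unsigned Op1Regs = (M + VF - 1) / VF * C;
      return Op0Regs + Op1Regs > NumVectorRegs;
    }
    // isFusionProfitableSketch(16, 16, 16, /*VF=*/2, /*Regs=*/32) -> true
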
+
+ MatrixTy getZeroMatrix(Type *EltType, unsigned R, unsigned C) {
+ MatrixTy Res;
+ auto *ColumnType = FixedVectorType::get(EltType, R);
+ for (unsigned I = 0; I < C; ++I)
+ Res.addVector(ConstantAggregateZero::get(ColumnType));
+ return Res;
+ }
+
+ void emitSIMDTiling(CallInst *MatMul, LoadInst *LoadOp0, LoadInst *LoadOp1,
+ StoreInst *Store,
+ SmallPtrSetImpl<Instruction *> &FusedInsts) {
+ assert(MatrixLayout == MatrixLayoutTy::ColumnMajor &&
+ "Tiling only supported for column-major matrixes at the moment!");
+ if (!isFusionProfitable(MatMul))
+ return;
+
+ ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
+ ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));
+
+ const unsigned R = LShape.NumRows;
+ const unsigned C = RShape.NumColumns;
+ const unsigned M = LShape.NumColumns;
+ auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();
+
+ Value *APtr = getNonAliasingPointer(LoadOp0, Store, MatMul);
+ Value *BPtr = getNonAliasingPointer(LoadOp1, Store, MatMul);
+ Value *CPtr = Store->getPointerOperand();
+
+ bool AllowContract = AllowContractEnabled || (isa<FPMathOperator>(MatMul) &&
+ MatMul->hasAllowContract());
+ IRBuilder<> Builder(Store);
+ for (unsigned J = 0; J < C; J += TileSize)
+ for (unsigned I = 0; I < R; I += TileSize) {
+ const unsigned TileR = std::min(R - I, unsigned(TileSize));
+ const unsigned TileC = std::min(C - J, unsigned(TileSize));
+ MatrixTy Res = getZeroMatrix(EltType, TileR, TileC);
+
+ for (unsigned K = 0; K < M; K += TileSize) {
+ const unsigned TileM = std::min(M - K, unsigned(TileSize));
+ MatrixTy A =
+ loadMatrix(APtr, LoadOp0->getAlign(), LoadOp0->isVolatile(),
+ LShape, Builder.getInt64(I), Builder.getInt64(K),
+ {TileR, TileM}, EltType, Builder);
+ MatrixTy B =
+ loadMatrix(BPtr, LoadOp1->getAlign(), LoadOp1->isVolatile(),
+ RShape, Builder.getInt64(K), Builder.getInt64(J),
+ {TileM, TileC}, EltType, Builder);
+ emitMatrixMultiply(Res, A, B, AllowContract, Builder, true);
+ }
+ storeMatrix(Res, CPtr, Store->getAlign(), Store->isVolatile(), {R, M},
+ Builder.getInt64(I), Builder.getInt64(J), EltType, Builder);
+ }
+
+ // Mark eliminated instructions as fused and remove them.
+ FusedInsts.insert(Store);
+ FusedInsts.insert(MatMul);
+ Store->eraseFromParent();
+ MatMul->eraseFromParent();
+ if (LoadOp0->hasNUses(0)) {
+ FusedInsts.insert(LoadOp0);
+ LoadOp0->eraseFromParent();
+ }
+ if (LoadOp1->hasNUses(0)) {
+ FusedInsts.insert(LoadOp1);
+ LoadOp1->eraseFromParent();
+ }
+ }
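
For R = C = M = 8 with the default TileSize of 4, the loops above visit four 4x4 result tiles, each accumulating two 4x4-by-4x4 partial products before its single store. A scalar sketch of the traversal order (illustration only; the real code keeps the tile in vector registers):

    #include <algorithm>

    // Scalar sketch of emitSIMDTiling's traversal: one result tile stays live
    // across the whole K loop, so each element of C is written exactly once.
    // Cm must be zero-initialized, cf. getZeroMatrix above.
    void tiledMatmulSketch(float *Cm, const float *A, const float *B,
                           unsigned R, unsigned M, unsigned C,
                           unsigned TileSize) {
      for (unsigned J = 0; J < C; J += TileSize)
        for (unsigned I = 0; I < R; I += TileSize) {
          unsigned TileR = std::min(R - I, TileSize);
          unsigned TileC = std::min(C - J, TileSize);
          for (unsigned K = 0; K < M; ++K) // accumulate into the live tile
            for (unsigned JJ = J; JJ < J + TileC; ++JJ)
              for (unsigned II = I; II < I + TileR; ++II)
                Cm[JJ * R + II] += A[K * R + II] * B[JJ * M + K];
        }
    }
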
+
+ /// Try to lower matrix multiply chains by fusing operations.
+ ///
+ /// Currently we only lower {ld, ld} -> matmul -> st chains.
+ //
+ /// No need to return a MatrixTy object for the result of the operation, since
+ /// the single store user will be lowered as part of this. Instructions that
+ /// are completely eliminated by fusion are added to \p FusedInsts.
+ void LowerMatrixMultiplyFused(CallInst *MatMul,
+ SmallPtrSetImpl<Instruction *> &FusedInsts) {
+ if (!FuseMatrix || !MatMul->hasOneUse() ||
+ MatrixLayout != MatrixLayoutTy::ColumnMajor)
+ return;
+
+ auto *LoadOp0 = dyn_cast<LoadInst>(MatMul->getOperand(0));
+ auto *LoadOp1 = dyn_cast<LoadInst>(MatMul->getOperand(1));
+ auto *Store = dyn_cast<StoreInst>(*MatMul->user_begin());
+ if (LoadOp0 && LoadOp1 && Store) {
+ // The store address must dominate the MatMul instruction, otherwise
+ // we create invalid IR.
+ // FIXME: See if we can hoist the store address computation.
+ auto *AddrI = dyn_cast<Instruction>(Store->getOperand(1));
+ if (AddrI && (!DT.dominates(AddrI, MatMul)))
+ return;
+
+ emitSIMDTiling(MatMul, LoadOp0, LoadOp1, Store, FusedInsts);
+ return;
+ }
+ }
+
/// Lowers llvm.matrix.multiply.
void LowerMultiply(CallInst *MatMul) {
IRBuilder<> Builder(MatMul);
@@ -706,97 +1298,80 @@ public:
ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));
- const ColumnMatrixTy &Lhs =
- getMatrix(MatMul->getArgOperand(0), LShape, Builder);
- const ColumnMatrixTy &Rhs =
- getMatrix(MatMul->getArgOperand(1), RShape, Builder);
+ const MatrixTy &Lhs = getMatrix(MatMul->getArgOperand(0), LShape, Builder);
+ const MatrixTy &Rhs = getMatrix(MatMul->getArgOperand(1), RShape, Builder);
const unsigned R = LShape.NumRows;
- const unsigned M = LShape.NumColumns;
const unsigned C = RShape.NumColumns;
- assert(M == RShape.NumRows);
+ assert(LShape.NumColumns == RShape.NumRows);
// Initialize the output
- ColumnMatrixTy Result;
- for (unsigned J = 0; J < C; ++J)
- Result.addColumn(UndefValue::get(VectorType::get(EltType, R)));
-
- const unsigned VF = std::max(TTI.getRegisterBitWidth(true) /
- EltType->getPrimitiveSizeInBits(),
- uint64_t(1));
+ MatrixTy Result(R, C, EltType);
bool AllowContract = AllowContractEnabled || (isa<FPMathOperator>(MatMul) &&
MatMul->hasAllowContract());
- // Multiply columns from the first operand with scalars from the second
- // operand. Then move along the K axes and accumulate the columns. With
- // this the adds can be vectorized without reassociation.
- for (unsigned J = 0; J < C; ++J) {
- unsigned BlockSize = VF;
- for (unsigned I = 0; I < R; I += BlockSize) {
- // Gradually lower the vectorization factor to cover the remainder.
- while (I + BlockSize > R)
- BlockSize /= 2;
-
- Value *Sum = nullptr;
- for (unsigned K = 0; K < M; ++K) {
- Value *L = extractVector(Lhs, I, K, BlockSize, Builder);
- Value *RH = Builder.CreateExtractElement(Rhs.getColumn(J), K);
- Value *Splat = Builder.CreateVectorSplat(BlockSize, RH, "splat");
- Sum = createMulAdd(Sum, L, Splat, EltType->isFloatingPointTy(),
- Builder, AllowContract);
- }
- Result.setColumn(J, insertVector(Result.getColumn(J), I, Sum, Builder));
- }
- }
+ emitMatrixMultiply(Result, Lhs, Rhs, AllowContract, Builder, false);
finalizeLowering(MatMul, Result, Builder);
}
/// Lowers llvm.matrix.transpose.
void LowerTranspose(CallInst *Inst) {
- ColumnMatrixTy Result;
+ MatrixTy Result;
IRBuilder<> Builder(Inst);
Value *InputVal = Inst->getArgOperand(0);
VectorType *VectorTy = cast<VectorType>(InputVal->getType());
ShapeInfo ArgShape(Inst->getArgOperand(1), Inst->getArgOperand(2));
- ColumnMatrixTy InputMatrix = getMatrix(InputVal, ArgShape, Builder);
-
- for (unsigned Row = 0; Row < ArgShape.NumRows; ++Row) {
- // Build a single column vector for this row. First initialize it.
- Value *ResultColumn = UndefValue::get(
- VectorType::get(VectorTy->getElementType(), ArgShape.NumColumns));
-
- // Go through the elements of this row and insert it into the resulting
- // column vector.
- for (auto C : enumerate(InputMatrix.columns())) {
- Value *Elt = Builder.CreateExtractElement(C.value(), Row);
- // We insert at index Column since that is the row index after the
- // transpose.
- ResultColumn =
- Builder.CreateInsertElement(ResultColumn, Elt, C.index());
+ MatrixTy InputMatrix = getMatrix(InputVal, ArgShape, Builder);
+
+ const unsigned NewNumVecs =
+ InputMatrix.isColumnMajor() ? ArgShape.NumRows : ArgShape.NumColumns;
+ const unsigned NewNumElts =
+ InputMatrix.isColumnMajor() ? ArgShape.NumColumns : ArgShape.NumRows;
+
+ for (unsigned I = 0; I < NewNumVecs; ++I) {
+ // Build a single result vector. First initialize it.
+ Value *ResultVector = UndefValue::get(
+ FixedVectorType::get(VectorTy->getElementType(), NewNumElts));
+ // Go through the old elements and insert it into the resulting vector.
+ for (auto J : enumerate(InputMatrix.vectors())) {
+ Value *Elt = Builder.CreateExtractElement(J.value(), I);
+ // Row and column indices are transposed.
+ ResultVector =
+ Builder.CreateInsertElement(ResultVector, Elt, J.index());
}
- Result.addColumn(ResultColumn);
+ Result.addVector(ResultVector);
}
- finalizeLowering(Inst, Result, Builder);
+ // TODO: Improve estimate of operations needed for transposes. Currently we
+ // just count the insertelement/extractelement instructions, but do not
+ // account for later simplifications/combines.
+ finalizeLowering(
+ Inst,
+ Result.addNumComputeOps(2 * ArgShape.NumRows * ArgShape.NumColumns),
+ Builder);
}
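
In scalar terms the loop moves input element (I, J) to output position (J, I), one extract/insert pair per entry, which matches the 2 * NumRows * NumColumns ops counted above. Sketch over column-major arrays (illustration only):

    // Scalar sketch of LowerTranspose for a column-major R x C input: element
    // (I, J) of the input lands at (J, I) of the C x R result.
    void transposeColumnMajorSketch(const float *In, float *Out, unsigned R,
                                    unsigned C) {
      for (unsigned J = 0; J < C; ++J)
        for (unsigned I = 0; I < R; ++I)
          Out[I * C + J] = In[J * R + I];
    }
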
/// Lower load instructions, if shape information is available.
- bool VisitLoad(Instruction *Inst, Value *Ptr, IRBuilder<> &Builder) {
+ bool VisitLoad(LoadInst *Inst, Value *Ptr, IRBuilder<> &Builder) {
auto I = ShapeMap.find(Inst);
if (I == ShapeMap.end())
return false;
- LowerLoad(Inst, Ptr, Builder.getInt32(I->second.NumRows), I->second);
+ LowerLoad(Inst, Ptr, Inst->getAlign(),
+ Builder.getInt64(I->second.getStride()), Inst->isVolatile(),
+ I->second);
return true;
}
- bool VisitStore(Instruction *Inst, Value *StoredVal, Value *Ptr,
+ bool VisitStore(StoreInst *Inst, Value *StoredVal, Value *Ptr,
IRBuilder<> &Builder) {
auto I = ShapeMap.find(StoredVal);
if (I == ShapeMap.end())
return false;
- LowerStore(Inst, StoredVal, Ptr, Builder.getInt32(I->second.NumRows), I->second);
+ LowerStore(Inst, StoredVal, Ptr, Inst->getAlign(),
+ Builder.getInt64(I->second.getStride()), Inst->isVolatile(),
+ I->second);
return true;
}
@@ -812,12 +1387,15 @@ public:
IRBuilder<> Builder(Inst);
ShapeInfo &Shape = I->second;
- ColumnMatrixTy LoweredLhs = getMatrix(Lhs, Shape, Builder);
- ColumnMatrixTy LoweredRhs = getMatrix(Rhs, Shape, Builder);
+ MatrixTy Result;
+ MatrixTy A = getMatrix(Lhs, Shape, Builder);
+ MatrixTy B = getMatrix(Rhs, Shape, Builder);
+ assert(A.isColumnMajor() == B.isColumnMajor() &&
+ Result.isColumnMajor() == A.isColumnMajor() &&
+ "operands must agree on matrix layout");
- // Add each column and store the result back into the opmapping
- ColumnMatrixTy Result;
- auto BuildColumnOp = [&Builder, Inst](Value *LHS, Value *RHS) {
+ // Helper to perform binary op on vectors.
+ auto BuildVectorOp = [&Builder, Inst](Value *LHS, Value *RHS) {
switch (Inst->getOpcode()) {
case Instruction::Add:
return Builder.CreateAdd(LHS, RHS);
@@ -835,20 +1413,462 @@ public:
llvm_unreachable("Unsupported binary operator for matrix");
}
};
- for (unsigned C = 0; C < Shape.NumColumns; ++C)
- Result.addColumn(
- BuildColumnOp(LoweredLhs.getColumn(C), LoweredRhs.getColumn(C)));
- finalizeLowering(Inst, Result, Builder);
+ for (unsigned I = 0; I < Shape.getNumVectors(); ++I)
+ Result.addVector(BuildVectorOp(A.getVector(I), B.getVector(I)));
+
+ finalizeLowering(Inst,
+ Result.addNumComputeOps(getNumOps(Result.getVectorTy()) *
+ Result.getNumVectors()),
+ Builder);
return true;
}
+
+ /// Helper to linearize a matrix expression tree into a string. Currently
+ /// matrix expressions are linearized by starting at an expression leaf and
+ /// linearizing bottom up.
+ struct ExprLinearizer {
+ unsigned LengthToBreak = 100;
+ std::string Str;
+ raw_string_ostream Stream;
+ unsigned LineLength = 0;
+ const DataLayout &DL;
+
+ /// Mapping from instructions to matrixes. It is used to identify
+ /// matrix instructions.
+ const MapVector<Value *, MatrixTy> &Inst2Matrix;
+
+ /// Mapping from values to the leaves of all expressions that the value is
+ /// part of.
+ const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared;
+
+ /// Set of matrix expressions in the scope of a given DISubprogram.
+ const SmallSetVector<Value *, 32> &ExprsInSubprogram;
+
+ /// Leaf node of the expression to linearize.
+ Value *Leaf;
+
+ /// Used to keep track of sub-expressions that get reused while linearizing
+ /// the expression. Re-used sub-expressions are marked as (reused).
+ SmallPtrSet<Value *, 8> ReusedExprs;
+
+ ExprLinearizer(const DataLayout &DL,
+ const MapVector<Value *, MatrixTy> &Inst2Matrix,
+ const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared,
+ const SmallSetVector<Value *, 32> &ExprsInSubprogram,
+ Value *Leaf)
+ : Str(), Stream(Str), DL(DL), Inst2Matrix(Inst2Matrix), Shared(Shared),
+ ExprsInSubprogram(ExprsInSubprogram), Leaf(Leaf) {}
+
+ void indent(unsigned N) {
+ LineLength += N;
+ for (unsigned i = 0; i < N; i++)
+ Stream << " ";
+ }
+
+ void lineBreak() {
+ Stream << "\n";
+ LineLength = 0;
+ }
+
+ void maybeIndent(unsigned Indent) {
+ if (LineLength >= LengthToBreak)
+ lineBreak();
+
+ if (LineLength == 0)
+ indent(Indent);
+ }
+
+ void write(StringRef S) {
+ LineLength += S.size();
+ Stream << S;
+ }
+
+ Value *getUnderlyingObjectThroughLoads(Value *V) {
+ if (Value *Ptr = getPointerOperand(V))
+ return getUnderlyingObjectThroughLoads(Ptr);
+ else if (V->getType()->isPointerTy())
+ return GetUnderlyingObject(V, DL);
+ return V;
+ }
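+
+ // For illustration: given
+ //   %p = load <8 x double>*, <8 x double>** %p.addr
+ //   %m = load <8 x double>, <8 x double>* %p
+ // the recursion walks through both loads and returns the underlying
+ // object of %p.addr (e.g. the alloca or argument it derives from).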
+
+ /// Returns true if \p V is a matrix value in the given subprogram.
+ bool isMatrix(Value *V) const { return ExprsInSubprogram.count(V); }
+
+ /// If \p V is a matrix value, print its shape as NumRows x NumColumns to
+ /// \p SS.
+ void prettyPrintMatrixType(Value *V, raw_string_ostream &SS) {
+ auto M = Inst2Matrix.find(V);
+ if (M == Inst2Matrix.end())
+ SS << "unknown";
+ else {
+ SS << M->second.getNumRows();
+ SS << "x";
+ SS << M->second.getNumColumns();
+ }
+ }
+
+ /// Write the called function name. Handles calls to llvm.matrix.*
+ /// specially: we write the name, followed by the dimensions of the input
+ /// matrices, followed by the scalar type name.
+ void writeFnName(CallInst *CI) {
+ if (!CI->getCalledFunction())
+ write("<no called fn>");
+ else {
+ StringRef Name = CI->getCalledFunction()->getName();
+ if (!Name.startswith("llvm.matrix")) {
+ write(Name);
+ return;
+ }
+ IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI);
+ write(StringRef(Intrinsic::getName(II->getIntrinsicID(), {}))
+ .drop_front(StringRef("llvm.matrix.").size()));
+ write(".");
+ std::string Tmp = "";
+ raw_string_ostream SS(Tmp);
+
+ switch (II->getIntrinsicID()) {
+ case Intrinsic::matrix_multiply:
+ prettyPrintMatrixType(II->getOperand(0), SS);
+ SS << ".";
+ prettyPrintMatrixType(II->getOperand(1), SS);
+ SS << "." << *II->getType()->getScalarType();
+ break;
+ case Intrinsic::matrix_transpose:
+ prettyPrintMatrixType(II->getOperand(0), SS);
+ SS << "." << *II->getType()->getScalarType();
+ break;
+ case Intrinsic::matrix_column_major_load:
+ prettyPrintMatrixType(II, SS);
+ SS << "." << *II->getType()->getScalarType();
+ break;
+ case Intrinsic::matrix_column_major_store:
+ prettyPrintMatrixType(II->getOperand(0), SS);
+ SS << "." << *II->getOperand(0)->getType()->getScalarType();
+ break;
+ default:
+ llvm_unreachable("Unhandled case");
+ }
+ SS.flush();
+ write(Tmp);
+ }
+ }
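+
+ // For illustration: for a call to llvm.matrix.multiply on a 4x2 and a
+ // 2x4 matrix of double, this writes "multiply.4x2.2x4.double".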
+
+ unsigned getNumShapeArgs(CallInst *CI) const {
+ if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI)) {
+ switch (II->getIntrinsicID()) {
+ case Intrinsic::matrix_multiply:
+ return 3;
+ case Intrinsic::matrix_transpose:
+ return 2;
+ case Intrinsic::matrix_column_major_load:
+ case Intrinsic::matrix_column_major_store:
+ return 3;
+ default:
+ return 0;
+ }
+ }
+ return 0;
+ }
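+
+ // For illustration: llvm.matrix.multiply(%A, %B, i32 M, i32 N, i32 K)
+ // carries its shape in the 3 trailing arguments; the caller drops them so
+ // that only the value operands (%A, %B) are printed.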
+
+ /// Special printing for values: for pointers, we print whether they refer
+ /// to a (function) external address or a stack address; for other values,
+ /// we print either the constant value or "scalar"/"matrix".
+ void write(Value *V) {
+ V = getUnderlyingObjectThroughLoads(V);
+ if (V->getType()->isPointerTy()) {
+ if (isa<AllocaInst>(V)) {
+ Stream << "stack addr";
+ LineLength += StringRef("stack addr").size();
+ } else {
+ Stream << "addr";
+ LineLength += StringRef("addr").size();
+ }
+ if (!V->getName().empty()) {
+ Stream << " %" << V->getName() << "";
+ LineLength += V->getName().size() + 2;
+ }
+ return;
+ }
+
+ std::string Tmp;
+ raw_string_ostream TmpStream(Tmp);
+
+ if (auto *CI = dyn_cast<ConstantInt>(V))
+ TmpStream << CI->getValue();
+ else if (isa<Constant>(V))
+ TmpStream << "constant";
+ else {
+ if (isMatrix(V))
+ TmpStream << "matrix";
+ else
+ TmpStream << "scalar";
+ }
+ TmpStream.flush();
+ Tmp = std::string(StringRef(Tmp).trim());
+ LineLength += Tmp.size();
+ Stream << Tmp;
+ }
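+
+ // For illustration: an alloca %A prints as "stack addr %A", another
+ // pointer as "addr %p", a ConstantInt as its value (e.g. "42"), other
+ // constants as "constant", and remaining values as "matrix" or "scalar".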
+
+ /// Linearize expression \p Expr starting at an indentation of \p Indent.
+ /// Expressions that are re-used multiple times are prefixed with (reused)
+ /// at the re-used root instruction.
+ void linearizeExpr(Value *Expr, unsigned Indent, bool ParentReused,
+ bool ParentShared) {
+ auto *I = cast<Instruction>(Expr);
+ maybeIndent(Indent);
+ SmallVector<Value *, 8> Ops;
+
+ // Is Expr shared with other expression leaves?
+ bool ExprShared = false;
+
+ // Deal with shared subtrees. Mark them as shared, if required.
+ if (!ParentShared) {
+ auto SI = Shared.find(Expr);
+ assert(SI != Shared.end() && SI->second.count(Leaf));
+
+ for (Value *S : SI->second) {
+ if (S == Leaf)
+ continue;
+ DebugLoc DL = cast<Instruction>(S)->getDebugLoc();
+ write("shared with remark at line " + std::to_string(DL.getLine()) +
+ " column " + std::to_string(DL.getCol()) + " (");
+ }
+ ExprShared = SI->second.size() > 1;
+ }
+
+ bool Reused = !ReusedExprs.insert(Expr).second;
+ if (Reused && !ParentReused)
+ write("(reused) ");
+
+ if (auto *CI = dyn_cast<CallInst>(I)) {
+ writeFnName(CI);
+
+ Ops.append(CI->arg_begin(), CI->arg_end() - getNumShapeArgs(CI));
+ } else if (isa<BitCastInst>(Expr)) {
+ // Special case bitcasts, which are used to materialize matrices from
+ // non-matrix ops.
+ write("matrix");
+ return;
+ } else {
+ Ops.append(I->value_op_begin(), I->value_op_end());
+ write(std::string(I->getOpcodeName()));
+ }
+
+ write(std::string("("));
+
+ unsigned NumOpsToBreak = 1;
+ if (match(Expr, m_Intrinsic<Intrinsic::matrix_column_major_load>()))
+ NumOpsToBreak = 2;
+
+ for (Value *Op : Ops) {
+ if (Ops.size() > NumOpsToBreak)
+ lineBreak();
+
+ maybeIndent(Indent + 1);
+ if (isMatrix(Op))
+ linearizeExpr(Op, Indent + 1, Reused, ExprShared);
+ else
+ write(Op);
+ if (Op != Ops.back())
+ write(", ");
+ }
+
+ write(")");
+ }
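+
+ // For illustration, linearizing a store leaf of a multiply might yield
+ // (hypothetical names and shapes):
+ //   store(
+ //    multiply.4x2.2x4.double(
+ //     load(addr %A),
+ //     load(addr %B)),
+ //    addr %C)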
+
+ const std::string &getResult() {
+ Stream.flush();
+ return Str;
+ }
+ };
+
+ /// Generate remarks for matrix operations in a function. To generate remarks
+ /// for matrix expressions, the following approach is used:
+ /// 1. Use the inlined-at debug information to group matrix operations to the
+ /// DISubprograms they are contained in.
+ /// 2. Collect leaves of matrix expressions (done in
+ /// RemarkGenerator::getExpressionLeaves) for each entry of the
+ /// subprogram-to-expressions mapping. Leaves are lowered matrix
+ /// instructions without other matrix users (like stores) in the current
+ /// subprogram.
+ /// 3. For each leaf, create a remark containing a linearized version of the
+ /// matrix expression. The expression is linearized by a recursive
+ /// bottom-up traversal of the matrix operands, starting at a leaf. Note
+ /// that multiple leaves can share sub-expressions. Shared subexpressions
+ /// are explicitly marked as shared().
+ struct RemarkGenerator {
+ const MapVector<Value *, MatrixTy> &Inst2Matrix;
+ OptimizationRemarkEmitter &ORE;
+ Function &Func;
+ const DataLayout &DL;
+
+ RemarkGenerator(const MapVector<Value *, MatrixTy> &Inst2Matrix,
+ OptimizationRemarkEmitter &ORE, Function &Func)
+ : Inst2Matrix(Inst2Matrix), ORE(ORE), Func(Func),
+ DL(Func.getParent()->getDataLayout()) {}
+
+ /// Return all leaves of the expressions in \p ExprsInSubprogram. Those are
+ /// instructions in Inst2Matrix returning void or without any users in
+ /// \p ExprsInSubprogram. Currently that should only include stores.
+ SmallVector<Value *, 4>
+ getExpressionLeaves(const SmallSetVector<Value *, 32> &ExprsInSubprogram) {
+ SmallVector<Value *, 4> Leaves;
+ for (auto *Expr : ExprsInSubprogram)
+ if (Expr->getType()->isVoidTy() ||
+ !any_of(Expr->users(), [&ExprsInSubprogram](User *U) {
+ return ExprsInSubprogram.count(U);
+ }))
+ Leaves.push_back(Expr);
+ return Leaves;
+ }
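+
+ // For illustration: given %P = llvm.matrix.multiply(...) whose only user
+ // is a store of %P, the void-typed store is the leaf; the multiply is
+ // not, since it has a user inside ExprsInSubprogram.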
+
+ /// Recursively traverse expression \p V starting at \p Leaf and add \p Leaf
+ /// to all visited expressions in \p Shared. Limit the matrix operations to
+ /// the ones in \p ExprsInSubprogram.
+ void collectSharedInfo(Value *Leaf, Value *V,
+ const SmallSetVector<Value *, 32> &ExprsInSubprogram,
+ DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared) {
+
+ if (!ExprsInSubprogram.count(V))
+ return;
+
+ auto I = Shared.insert({V, {}});
+ I.first->second.insert(Leaf);
+
+ for (Value *Op : cast<Instruction>(V)->operand_values())
+ collectSharedInfo(Leaf, Op, ExprsInSubprogram, Shared);
+ }
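+
+ // For illustration: if two store leaves both consume the same load, the
+ // load ends up with Shared[load] == {store1, store2}, and each leaf's
+ // remark flags the load as shared with the other expression.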
+
+ /// Calculate the exclusive and shared op counts for the expression
+ /// starting at \p Root. Expressions used multiple times are counted once.
+ /// Limit the matrix operations to the ones in \p ExprsInSubprogram.
+ std::pair<OpInfoTy, OpInfoTy>
+ sumOpInfos(Value *Root, SmallPtrSetImpl<Value *> &ReusedExprs,
+ const SmallSetVector<Value *, 32> &ExprsInSubprogram,
+ DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared) const {
+ if (!ExprsInSubprogram.count(Root))
+ return {};
+
+ // Already counted this expression. Stop.
+ if (!ReusedExprs.insert(Root).second)
+ return {};
+
+ OpInfoTy SharedCount;
+ OpInfoTy Count;
+
+ auto I = Shared.find(Root);
+ auto CM = Inst2Matrix.find(Root);
+ if (I->second.size() == 1)
+ Count = CM->second.getOpInfo();
+ else
+ SharedCount = CM->second.getOpInfo();
+
+ for (Value *Op : cast<Instruction>(Root)->operand_values()) {
+ auto C = sumOpInfos(Op, ReusedExprs, ExprsInSubprogram, Shared);
+ Count += C.first;
+ SharedCount += C.second;
+ }
+ return {Count, SharedCount};
+ }
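+
+ // For illustration: a sub-expression reachable from a single leaf
+ // (Shared[Root].size() == 1) contributes to the exclusive Count; a load
+ // feeding two leaves contributes its OpInfo to SharedCount instead.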
+
+ void emitRemarks() {
+ if (!ORE.allowExtraAnalysis(DEBUG_TYPE))
+ return;
+
+ // Map matrix operations to their containing subprograms, by traversing
+ // the inlinedAt chain. If the function does not have a DISubprogram, we
+ // only map them to the containing function.
+ MapVector<DISubprogram *, SmallVector<Value *, 8>> Subprog2Exprs;
+ for (auto &KV : Inst2Matrix) {
+ if (Func.getSubprogram()) {
+ auto *I = cast<Instruction>(KV.first);
+ DILocation *Context = I->getDebugLoc();
+ while (Context) {
+ auto I =
+ Subprog2Exprs.insert({getSubprogram(Context->getScope()), {}});
+ I.first->second.push_back(KV.first);
+ Context = DebugLoc(Context).getInlinedAt();
+ }
+ } else {
+ auto I = Subprog2Exprs.insert({nullptr, {}});
+ I.first->second.push_back(KV.first);
+ }
+ }
+ for (auto &KV : Subprog2Exprs) {
+ SmallSetVector<Value *, 32> ExprsInSubprogram(KV.second.begin(),
+ KV.second.end());
+ auto Leaves = getExpressionLeaves(ExprsInSubprogram);
+
+ DenseMap<Value *, SmallPtrSet<Value *, 2>> Shared;
+ for (Value *Leaf : Leaves)
+ collectSharedInfo(Leaf, Leaf, ExprsInSubprogram, Shared);
+
+ // Generate remarks for each leaf.
+ for (auto *L : Leaves) {
+
+ DebugLoc Loc = cast<Instruction>(L)->getDebugLoc();
+ DILocation *Context = cast<Instruction>(L)->getDebugLoc();
+ while (Context) {
+ if (getSubprogram(Context->getScope()) == KV.first) {
+ Loc = Context;
+ break;
+ }
+ Context = DebugLoc(Context).getInlinedAt();
+ }
+
+ SmallPtrSet<Value *, 8> ReusedExprs;
+ OpInfoTy Counts, SharedCounts;
+ std::tie(Counts, SharedCounts) =
+ sumOpInfos(L, ReusedExprs, ExprsInSubprogram, Shared);
+
+ OptimizationRemark Rem(DEBUG_TYPE, "matrix-lowered", Loc,
+ cast<Instruction>(L)->getParent());
+
+ Rem << "Lowered with ";
+ Rem << ore::NV("NumStores", Counts.NumStores) << " stores, "
+ << ore::NV("NumLoads", Counts.NumLoads) << " loads, "
+ << ore::NV("NumComputeOps", Counts.NumComputeOps)
+ << " compute ops";
+
+ if (SharedCounts.NumStores > 0 || SharedCounts.NumLoads > 0 ||
+ SharedCounts.NumComputeOps > 0) {
+ Rem << ",\nadditionally "
+ << ore::NV("NumStores", SharedCounts.NumStores) << " stores, "
+ << ore::NV("NumLoads", SharedCounts.NumLoads) << " loads, "
+ << ore::NV("NumFPOps", SharedCounts.NumComputeOps)
+ << " compute ops"
+ << " are shared with other expressions";
+ }
+
+ Rem << ("\n" + linearize(L, Shared, ExprsInSubprogram, DL));
+ ORE.emit(Rem);
+ }
+ }
+ }
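+
+ // For illustration, an emitted remark might read (hypothetical location
+ // and counts):
+ //   remark: test.cpp:35:3: Lowered with 2 stores, 4 loads, 8 compute ops
+ //   store(
+ //    add(
+ //     load(addr %A),
+ //     load(addr %B)),
+ //    addr %C)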
+
+ std::string
+ linearize(Value *L,
+ const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared,
+ const SmallSetVector<Value *, 32> &ExprsInSubprogram,
+ const DataLayout &DL) {
+ ExprLinearizer Lin(DL, Inst2Matrix, Shared, ExprsInSubprogram, L);
+ Lin.linearizeExpr(L, 0, false, false);
+ return Lin.getResult();
+ }
+ };
};
} // namespace
PreservedAnalyses LowerMatrixIntrinsicsPass::run(Function &F,
FunctionAnalysisManager &AM) {
auto &TTI = AM.getResult<TargetIRAnalysis>(F);
- LowerMatrixIntrinsics LMT(F, TTI);
+ auto &ORE = AM.getResult<OptimizationRemarkEmitterAnalysis>(F);
+ auto &AA = AM.getResult<AAManager>(F);
+ auto &DT = AM.getResult<DominatorTreeAnalysis>(F);
+ auto &LI = AM.getResult<LoopAnalysis>(F);
+
+ LowerMatrixIntrinsics LMT(F, TTI, AA, DT, LI, ORE);
if (LMT.Visit()) {
PreservedAnalyses PA;
PA.preserveSet<CFGAnalyses>();
@@ -869,15 +1889,24 @@ public:
}
bool runOnFunction(Function &F) override {
- auto *TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
- LowerMatrixIntrinsics LMT(F, *TTI);
+ auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
+ auto &ORE = getAnalysis<OptimizationRemarkEmitterWrapperPass>().getORE();
+ auto &AA = getAnalysis<AAResultsWrapperPass>().getAAResults();
+ auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
+ auto &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
+ LowerMatrixIntrinsics LMT(F, TTI, AA, DT, LI, ORE);
bool C = LMT.Visit();
return C;
}
void getAnalysisUsage(AnalysisUsage &AU) const override {
AU.addRequired<TargetTransformInfoWrapperPass>();
- AU.setPreservesCFG();
+ AU.addRequired<OptimizationRemarkEmitterWrapperPass>();
+ AU.addRequired<AAResultsWrapperPass>();
+ AU.addRequired<DominatorTreeWrapperPass>();
+ AU.addPreserved<DominatorTreeWrapperPass>();
+ AU.addRequired<LoopInfoWrapperPass>();
+ AU.addPreserved<LoopInfoWrapperPass>();
}
};
} // namespace
@@ -886,6 +1915,10 @@ static const char pass_name[] = "Lower the matrix intrinsics";
char LowerMatrixIntrinsicsLegacyPass::ID = 0;
INITIALIZE_PASS_BEGIN(LowerMatrixIntrinsicsLegacyPass, DEBUG_TYPE, pass_name,
false, false)
+INITIALIZE_PASS_DEPENDENCY(OptimizationRemarkEmitterWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
INITIALIZE_PASS_END(LowerMatrixIntrinsicsLegacyPass, DEBUG_TYPE, pass_name,
false, false)