diff options
| author | Dimitry Andric <dim@FreeBSD.org> | 2020-07-26 19:36:28 +0000 | 
|---|---|---|
| committer | Dimitry Andric <dim@FreeBSD.org> | 2020-07-26 19:36:28 +0000 | 
| commit | cfca06d7963fa0909f90483b42a6d7d194d01e08 (patch) | |
| tree | 209fb2a2d68f8f277793fc8df46c753d31bc853b /llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp | |
| parent | 706b4fc47bbc608932d3b491ae19a3b9cde9497b (diff) | |
Notes
Diffstat (limited to 'llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp')
| -rw-r--r-- | llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp | 1531 | 
1 files changed, 1282 insertions, 249 deletions
| diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp index 0ff6ee8bcfcc..90314b17b5e2 100644 --- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp +++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp @@ -9,8 +9,11 @@  // Lower matrix intrinsics to vector operations.  //  // TODO: -//  * Implement multiply & add fusion -//  * Add remark, summarizing the available matrix optimization opportunities. +//  * Improve fusion: +//   * Support more cases, e.g. multiply-add, multiply-sub, operands/results +//     transposed. +//   * Improve cost-modeling, e.g. choose different number of rows/columns +//     columns for tiles, consider cost of copies on alias.  //  //===----------------------------------------------------------------------===// @@ -18,10 +21,15 @@  #include "llvm/ADT/GraphTraits.h"  #include "llvm/ADT/PostOrderIterator.h"  #include "llvm/ADT/SmallVector.h" +#include "llvm/Analysis/AliasAnalysis.h" +#include "llvm/Analysis/DomTreeUpdater.h" +#include "llvm/Analysis/OptimizationRemarkEmitter.h"  #include "llvm/Analysis/TargetTransformInfo.h" +#include "llvm/Analysis/ValueTracking.h"  #include "llvm/Analysis/VectorUtils.h"  #include "llvm/IR/CFG.h"  #include "llvm/IR/DataLayout.h" +#include "llvm/IR/DebugInfoMetadata.h"  #include "llvm/IR/Function.h"  #include "llvm/IR/IRBuilder.h"  #include "llvm/IR/Instructions.h" @@ -29,30 +37,69 @@  #include "llvm/IR/PatternMatch.h"  #include "llvm/InitializePasses.h"  #include "llvm/Pass.h" +#include "llvm/Support/Alignment.h" +#include "llvm/Support/CommandLine.h"  #include "llvm/Support/Debug.h"  #include "llvm/Transforms/Scalar.h" +#include "llvm/Transforms/Utils/BasicBlockUtils.h"  using namespace llvm;  using namespace PatternMatch;  #define DEBUG_TYPE "lower-matrix-intrinsics" -static cl::opt<bool> EnableShapePropagation("matrix-propagate-shape", -                                            cl::init(true)); - +static cl::opt<bool> EnableShapePropagation( +    "matrix-propagate-shape", cl::init(true), cl::Hidden, +    cl::desc("Enable/disable shape propagation from matrix intrinsics to other " +             "instructions.")); + +static cl::opt<bool> +    FuseMatrix("fuse-matrix", cl::init(true), cl::Hidden, +               cl::desc("Enable/disable fusing matrix instructions.")); +// TODO: Allow and use non-square tiles. +static cl::opt<unsigned> TileSize( +    "fuse-matrix-tile-size", cl::init(4), cl::Hidden, +    cl::desc( +        "Tile size for matrix instruction fusion using square-shaped tiles.")); +static cl::opt<bool> ForceFusion( +    "force-fuse-matrix", cl::init(false), cl::Hidden, +    cl::desc("Force matrix instruction fusion even if not profitable."));  static cl::opt<bool> AllowContractEnabled(      "matrix-allow-contract", cl::init(false), cl::Hidden,      cl::desc("Allow the use of FMAs if available and profitable. This may "               "result in different results, due to less rounding error.")); +enum class MatrixLayoutTy { ColumnMajor, RowMajor }; + +static cl::opt<MatrixLayoutTy> MatrixLayout( +    "matrix-default-layout", cl::init(MatrixLayoutTy::ColumnMajor), +    cl::desc("Sets the default matrix layout"), +    cl::values(clEnumValN(MatrixLayoutTy::ColumnMajor, "column-major", +                          "Use column-major layout"), +               clEnumValN(MatrixLayoutTy::RowMajor, "row-major", +                          "Use row-major layout"))); + +/// Helper function to either return Scope, if it is a subprogram or the +/// attached subprogram for a local scope. +static DISubprogram *getSubprogram(DIScope *Scope) { +  if (auto *Subprogram = dyn_cast<DISubprogram>(Scope)) +    return Subprogram; +  return cast<DILocalScope>(Scope)->getSubprogram(); +} +  namespace { -// Given an element poitner \p BasePtr to the start of a (sub) matrix, compute -// the start address of column \p Col with type (\p EltType x \p NumRows) -// assuming \p Stride elements between start two consecutive columns. -// \p Stride must be >= \p NumRows. +// Given an element pointer \p BasePtr to the start of a (sub) matrix, compute +// the start address of vector \p VecIdx with type (\p EltType x \p NumElements) +// assuming \p Stride elements between start two consecutive vectors. +// \p Stride must be >= \p NumElements. +// For column-major matrixes, the function computes the address of a column +// vectors and \p NumElements must be set to the number of elements in a column +// (= number of rows of the matrix). For row-major matrixes, the function +// computes the address of a row vector and \p NumElements must be set to the +// number of elements in a column (= number of columns of the matrix).  // -// Consider a 4x4 matrix like below +// Consider a 4x4 matrix in column-mjaor layout like below  //  //      0       1      2      3  // 0   v_0_0  v_0_1  v_0_2  v_0_3 @@ -62,14 +109,14 @@ namespace {  // To compute the column addresses for a 2x3 sub-matrix at row 1 and column 1,  // we need a pointer to the first element of the submatrix as base pointer. -// Then we can use computeColumnAddr to compute the addresses for the columns +// Then we can use computeVectorAddr to compute the addresses for the columns  // of the sub-matrix.  // -// Column 0: computeColumnAddr(Base, 0 (column), 4 (stride), 2 (num rows), ..) +// Column 0: computeVectorAddr(Base, 0 (column), 4 (stride), 2 (num rows), ..)  //           -> just returns Base -// Column 1: computeColumnAddr(Base, 1 (column), 4 (stride), 2 (num rows), ..) +// Column 1: computeVectorAddr(Base, 1 (column), 4 (stride), 2 (num rows), ..)  //           -> returns Base + (1 * 4) -// Column 2: computeColumnAddr(Base, 2 (column), 4 (stride), 2 (num rows), ..) +// Column 2: computeVectorAddr(Base, 2 (column), 4 (stride), 2 (num rows), ..)  //           -> returns Base + (2 * 4)  //  // The graphic below illustrates the number of elements in a column (marked @@ -82,30 +129,30 @@ namespace {  //         v_2_0 |v_2_1 |v_2_2 |v_2_3  //         v_3_0 {v_3_1 {v_3_2  v_3_3  // -Value *computeColumnAddr(Value *BasePtr, Value *Col, Value *Stride, -                         unsigned NumRows, Type *EltType, +Value *computeVectorAddr(Value *BasePtr, Value *VecIdx, Value *Stride, +                         unsigned NumElements, Type *EltType,                           IRBuilder<> &Builder) {    assert((!isa<ConstantInt>(Stride) || -          cast<ConstantInt>(Stride)->getZExtValue() >= NumRows) && -         "Stride must be >= the number of rows."); +          cast<ConstantInt>(Stride)->getZExtValue() >= NumElements) && +         "Stride must be >= the number of elements in the result vector.");    unsigned AS = cast<PointerType>(BasePtr->getType())->getAddressSpace(); -  // Compute the start of the column with index Col as Col * Stride. -  Value *ColumnStart = Builder.CreateMul(Col, Stride, "col.start"); +  // Compute the start of the vector with index VecIdx as VecIdx * Stride. +  Value *VecStart = Builder.CreateMul(VecIdx, Stride, "vec.start"); -  // Get pointer to the start of the selected column. Skip GEP creation, -  // if we select column 0. -  if (isa<ConstantInt>(ColumnStart) && cast<ConstantInt>(ColumnStart)->isZero()) -    ColumnStart = BasePtr; +  // Get pointer to the start of the selected vector. Skip GEP creation, +  // if we select vector 0. +  if (isa<ConstantInt>(VecStart) && cast<ConstantInt>(VecStart)->isZero()) +    VecStart = BasePtr;    else -    ColumnStart = Builder.CreateGEP(EltType, BasePtr, ColumnStart, "col.gep"); +    VecStart = Builder.CreateGEP(EltType, BasePtr, VecStart, "vec.gep"); -  // Cast elementwise column start pointer to a pointer to a column -  // (EltType x NumRows)*. -  Type *ColumnType = VectorType::get(EltType, NumRows); -  Type *ColumnPtrType = PointerType::get(ColumnType, AS); -  return Builder.CreatePointerCast(ColumnStart, ColumnPtrType, "col.cast"); +  // Cast elementwise vector start pointer to a pointer to a vector +  // (EltType x NumElements)*. +  auto *VecType = FixedVectorType::get(EltType, NumElements); +  Type *VecPtrType = PointerType::get(VecType, AS); +  return Builder.CreatePointerCast(VecStart, VecPtrType, "vec.cast");  }  /// LowerMatrixIntrinsics contains the methods used to lower matrix intrinsics. @@ -113,15 +160,16 @@ Value *computeColumnAddr(Value *BasePtr, Value *Col, Value *Stride,  /// Currently, the lowering for each matrix intrinsic is done as follows:  /// 1. Propagate the shape information from intrinsics to connected  /// instructions. -/// 2. Lower instructions with shape information. +/// 2. Lower instructions with shape information (assuming column-major layout). +///  The lowering works similarly using row-major layout.  ///  2.1. Get column vectors for each argument. If we already lowered the  ///       definition of an argument, use the produced column vectors directly.  ///       If not, split the operand vector containing an embedded matrix into  ///       a set of column vectors, -///  2.2. Lower the instruction in terms of columnwise operations, which yields -///       a set of column vectors containing result matrix. Note that we lower -///       all instructions that have shape information. Besides the intrinsics, -///       this includes stores for example. +///  2.2. Lower the instruction in terms of column major operations, which +///       yields a set of column vectors containing result matrix. Note that we +///       lower all instructions that have shape information. Besides the +///       intrinsics, this includes stores for example.  ///  2.3. Update uses of the lowered instruction. If we have shape information  ///       for a user, there is nothing to do, as we will look up the result  ///       column matrix when lowering the user. For other uses, we embed the @@ -134,42 +182,157 @@ class LowerMatrixIntrinsics {    Function &Func;    const DataLayout &DL;    const TargetTransformInfo &TTI; +  AliasAnalysis &AA; +  DominatorTree &DT; +  LoopInfo &LI; +  OptimizationRemarkEmitter &ORE; + +  /// Contains estimates of the number of operations (loads, stores, compute) required to lower a matrix operation. +  struct OpInfoTy { +    /// Number of stores emitted to generate this matrix. +    unsigned NumStores = 0; +    /// Number of loads emitted to generate this matrix. +    unsigned NumLoads = 0; +    /// Number of compute operations emitted to generate this matrix. +    unsigned NumComputeOps = 0; + +    OpInfoTy &operator+=(const OpInfoTy &RHS) { +      NumStores += RHS.NumStores; +      NumLoads += RHS.NumLoads; +      NumComputeOps += RHS.NumComputeOps; +      return *this; +    } +  }; + +  /// Wrapper class representing a matrix as a set of vectors, either in row or +  /// column major layout. All vectors must have the same vector type. +  class MatrixTy { +    SmallVector<Value *, 16> Vectors; + +    OpInfoTy OpInfo; -  /// Wrapper class representing a matrix as a set of column vectors. -  /// All column vectors must have the same vector type. -  class ColumnMatrixTy { -    SmallVector<Value *, 16> Columns; +    bool IsColumnMajor = true;    public: -    ColumnMatrixTy() : Columns() {} -    ColumnMatrixTy(ArrayRef<Value *> Cols) -        : Columns(Cols.begin(), Cols.end()) {} +    MatrixTy() +        : Vectors(), +          IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {} +    MatrixTy(ArrayRef<Value *> Vectors) +        : Vectors(Vectors.begin(), Vectors.end()), +          IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {} +    MatrixTy(unsigned NumRows, unsigned NumColumns, Type *EltTy) +        : IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) { + +      unsigned D = isColumnMajor() ? NumColumns : NumRows; +      for (unsigned J = 0; J < D; ++J) +        addVector(UndefValue::get(FixedVectorType::get( +            EltTy, isColumnMajor() ? NumRows : NumColumns))); +    } + +    Value *getVector(unsigned i) const { return Vectors[i]; } +    Value *getColumn(unsigned i) const { +      assert(isColumnMajor() && "only supported for column-major matrixes"); +      return Vectors[i]; +    } +    Value *getRow(unsigned i) const { +      assert(!isColumnMajor() && "only supported for row-major matrixes"); +      return Vectors[i]; +    } -    Value *getColumn(unsigned i) const { return Columns[i]; } +    void setVector(unsigned i, Value *V) { Vectors[i] = V; } -    void setColumn(unsigned i, Value *V) { Columns[i] = V; } +    Type *getElementType() { return getVectorTy()->getElementType(); } -    size_t getNumColumns() const { return Columns.size(); } -    size_t getNumRows() const { -      assert(Columns.size() > 0 && "Cannot call getNumRows without columns"); -      return cast<VectorType>(Columns[0]->getType())->getNumElements(); +    unsigned getNumVectors() const { +      if (isColumnMajor()) +        return getNumColumns(); +      return getNumRows();      } -    const SmallVectorImpl<Value *> &getColumnVectors() const { return Columns; } +    unsigned getNumColumns() const { +      if (isColumnMajor()) +        return Vectors.size(); +      else { +        assert(Vectors.size() > 0 && "Cannot call getNumRows without columns"); +        return cast<FixedVectorType>(Vectors[0]->getType())->getNumElements(); +      } +    } +    unsigned getNumRows() const { +      if (isColumnMajor()) { +        assert(Vectors.size() > 0 && "Cannot call getNumRows without columns"); +        return cast<FixedVectorType>(Vectors[0]->getType())->getNumElements(); +      } else +        return Vectors.size(); +    } -    SmallVectorImpl<Value *> &getColumnVectors() { return Columns; } +    void addVector(Value *V) { Vectors.push_back(V); } +    VectorType *getColumnTy() { +      assert(isColumnMajor() && "only supported for column-major matrixes"); +      return getVectorTy(); +    } -    void addColumn(Value *V) { Columns.push_back(V); } +    VectorType *getVectorTy() { +      return cast<VectorType>(Vectors[0]->getType()); +    }      iterator_range<SmallVector<Value *, 8>::iterator> columns() { -      return make_range(Columns.begin(), Columns.end()); +      assert(isColumnMajor() && +             "columns() only supported for column-major matrixes"); +      return make_range(Vectors.begin(), Vectors.end());      } -    /// Embed the columns of the matrix into a flat vector by concatenating +    iterator_range<SmallVector<Value *, 8>::iterator> vectors() { +      return make_range(Vectors.begin(), Vectors.end()); +    } + +    /// Embed the vectors of the matrix into a flat vector by concatenating      /// them.      Value *embedInVector(IRBuilder<> &Builder) const { -      return Columns.size() == 1 ? Columns[0] -                                 : concatenateVectors(Builder, Columns); +      return Vectors.size() == 1 ? Vectors[0] +                                 : concatenateVectors(Builder, Vectors); +    } + +    MatrixTy &addNumLoads(unsigned N) { +      OpInfo.NumLoads += N; +      return *this; +    } + +    void setNumLoads(unsigned N) { OpInfo.NumLoads = N; } + +    MatrixTy &addNumStores(unsigned N) { +      OpInfo.NumStores += N; +      return *this; +    } + +    MatrixTy &addNumComputeOps(unsigned N) { +      OpInfo.NumComputeOps += N; +      return *this; +    } + +    unsigned getNumStores() const { return OpInfo.NumStores; } +    unsigned getNumLoads() const { return OpInfo.NumLoads; } +    unsigned getNumComputeOps() const { return OpInfo.NumComputeOps; } + +    const OpInfoTy &getOpInfo() const { return OpInfo; } + +    bool isColumnMajor() const { return IsColumnMajor; } + +    unsigned getStride() const { +      if (isColumnMajor()) +        return getNumRows(); +      return getNumColumns(); +    } + +    /// Extract a vector of \p NumElts starting at index (\p I, \p J). If the +    /// matrix is column-major, the result vector is extracted from a column +    /// vector, otherwise from a row vector. +    Value *extractVector(unsigned I, unsigned J, unsigned NumElts, +                         IRBuilder<> &Builder) const { +      Value *Vec = isColumnMajor() ? getColumn(J) : getRow(I); +      Value *Undef = UndefValue::get(Vec->getType()); +      return Builder.CreateShuffleVector( +          Vec, Undef, createSequentialMask(isColumnMajor() ? I : J, NumElts, 0), +          "block");      }    }; @@ -177,12 +340,15 @@ class LowerMatrixIntrinsics {      unsigned NumRows;      unsigned NumColumns; +    bool IsColumnMajor; +      ShapeInfo(unsigned NumRows = 0, unsigned NumColumns = 0) -        : NumRows(NumRows), NumColumns(NumColumns) {} +        : NumRows(NumRows), NumColumns(NumColumns), +          IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {}      ShapeInfo(Value *NumRows, Value *NumColumns) -        : NumRows(cast<ConstantInt>(NumRows)->getZExtValue()), -          NumColumns(cast<ConstantInt>(NumColumns)->getZExtValue()) {} +        : ShapeInfo(cast<ConstantInt>(NumRows)->getZExtValue(), +                    cast<ConstantInt>(NumColumns)->getZExtValue()) {}      bool operator==(const ShapeInfo &other) {        return NumRows == other.NumRows && NumColumns == other.NumColumns; @@ -195,12 +361,24 @@ class LowerMatrixIntrinsics {        assert(NumRows == 0 || NumColumns != 0);        return NumRows != 0;      } + +    unsigned getStride() const { +      if (IsColumnMajor) +        return NumRows; +      return NumColumns; +    } + +    unsigned getNumVectors() const { +      if (IsColumnMajor) +        return NumColumns; +      return NumRows; +    }    };    /// Maps instructions to their shape information. The shape information    /// describes the shape to be used while lowering. This matches the shape of    /// the result value of the instruction, with the only exceptions being store -  /// instructions and the matrix_columnwise_store intrinsics. For those, the +  /// instructions and the matrix_column_major_store intrinsics. For those, the    /// shape information indicates that those instructions should be lowered    /// using shape information as well.    DenseMap<Value *, ShapeInfo> ShapeMap; @@ -211,31 +389,49 @@ class LowerMatrixIntrinsics {    SmallVector<Instruction *, 16> ToRemove;    /// Map from instructions to their produced column matrix. -  DenseMap<Value *, ColumnMatrixTy> Inst2ColumnMatrix; +  MapVector<Value *, MatrixTy> Inst2ColumnMatrix;  public: -  LowerMatrixIntrinsics(Function &F, TargetTransformInfo &TTI) -      : Func(F), DL(F.getParent()->getDataLayout()), TTI(TTI) {} +  LowerMatrixIntrinsics(Function &F, TargetTransformInfo &TTI, +                        AliasAnalysis &AA, DominatorTree &DT, LoopInfo &LI, +                        OptimizationRemarkEmitter &ORE) +      : Func(F), DL(F.getParent()->getDataLayout()), TTI(TTI), AA(AA), DT(DT), +        LI(LI), ORE(ORE) {} + +  unsigned getNumOps(Type *VT) { +    assert(isa<VectorType>(VT) && "Expected vector type"); +    return getNumOps(VT->getScalarType(), +                     cast<FixedVectorType>(VT)->getNumElements()); +  } -  /// Return the set of column vectors that a matrix value is lowered to. +  // +  /// Return the estimated number of vector ops required for an operation on +  /// \p VT * N. +  unsigned getNumOps(Type *ST, unsigned N) { +    return std::ceil((ST->getPrimitiveSizeInBits() * N).getFixedSize() / +                     double(TTI.getRegisterBitWidth(true))); +  } + +  /// Return the set of vectors that a matrix value is lowered to.    /// -  /// If we lowered \p MatrixVal, just return the cache result column matrix. -  /// Otherwie split the flat vector \p MatrixVal containing a matrix with -  /// shape \p SI into column vectors. -  ColumnMatrixTy getMatrix(Value *MatrixVal, const ShapeInfo &SI, -                           IRBuilder<> Builder) { +  /// If we lowered \p MatrixVal, just return the cache result matrix. Otherwise +  /// split the flat vector \p MatrixVal containing a matrix with shape \p SI +  /// into vectors. +  MatrixTy getMatrix(Value *MatrixVal, const ShapeInfo &SI, +                     IRBuilder<> &Builder) {      VectorType *VType = dyn_cast<VectorType>(MatrixVal->getType());      assert(VType && "MatrixVal must be a vector type"); -    assert(VType->getNumElements() == SI.NumRows * SI.NumColumns && +    assert(cast<FixedVectorType>(VType)->getNumElements() == +               SI.NumRows * SI.NumColumns &&             "The vector size must match the number of matrix elements");      // Check if we lowered MatrixVal using shape information. In that case, -    // return the existing column matrix, if it matches the requested shape +    // return the existing matrix, if it matches the requested shape      // information. If there is a mis-match, embed the result in a flat      // vector and split it later.      auto Found = Inst2ColumnMatrix.find(MatrixVal);      if (Found != Inst2ColumnMatrix.end()) { -      ColumnMatrixTy &M = Found->second; +      MatrixTy &M = Found->second;        // Return the found matrix, if its shape matches the requested shape        // information        if (SI.NumRows == M.getNumRows() && SI.NumColumns == M.getNumColumns()) @@ -247,10 +443,12 @@ public:      // Otherwise split MatrixVal.      SmallVector<Value *, 16> SplitVecs;      Value *Undef = UndefValue::get(VType); -    for (unsigned MaskStart = 0; MaskStart < VType->getNumElements(); -         MaskStart += SI.NumRows) { -      Constant *Mask = createSequentialMask(Builder, MaskStart, SI.NumRows, 0); -      Value *V = Builder.CreateShuffleVector(MatrixVal, Undef, Mask, "split"); +    for (unsigned MaskStart = 0; +         MaskStart < cast<FixedVectorType>(VType)->getNumElements(); +         MaskStart += SI.getStride()) { +      Value *V = Builder.CreateShuffleVector( +          MatrixVal, Undef, createSequentialMask(MaskStart, SI.getStride(), 0), +          "split");        SplitVecs.push_back(V);      } @@ -308,8 +506,8 @@ public:        switch (II->getIntrinsicID()) {        case Intrinsic::matrix_multiply:        case Intrinsic::matrix_transpose: -      case Intrinsic::matrix_columnwise_load: -      case Intrinsic::matrix_columnwise_store: +      case Intrinsic::matrix_column_major_load: +      case Intrinsic::matrix_column_major_store:          return true;        default:          return false; @@ -348,13 +546,13 @@ public:                                   m_Value(MatrixA), m_Value(M), m_Value(N)))) {          // Flip dimensions.          Propagate = setShapeInfo(Inst, {N, M}); -      } else if (match(Inst, m_Intrinsic<Intrinsic::matrix_columnwise_store>( +      } else if (match(Inst, m_Intrinsic<Intrinsic::matrix_column_major_store>(                                   m_Value(MatrixA), m_Value(), m_Value(), -                                 m_Value(M), m_Value(N)))) { +                                 m_Value(), m_Value(M), m_Value(N)))) {          Propagate = setShapeInfo(Inst, {N, M}); -      } else if (match(Inst, -                       m_Intrinsic<Intrinsic::matrix_columnwise_load>( -                           m_Value(), m_Value(), m_Value(M), m_Value(N)))) { +      } else if (match(Inst, m_Intrinsic<Intrinsic::matrix_column_major_load>( +                                 m_Value(), m_Value(), m_Value(), m_Value(M), +                                 m_Value(N)))) {          Propagate = setShapeInfo(Inst, {M, N});        } else if (match(Inst, m_Store(m_Value(MatrixA), m_Value()))) {          auto OpShape = ShapeMap.find(MatrixA); @@ -426,14 +624,14 @@ public:          // Flip dimensions.          if (setShapeInfo(MatrixA, {M, N}))            pushInstruction(MatrixA, WorkList); -      } else if (match(V, m_Intrinsic<Intrinsic::matrix_columnwise_store>( -                              m_Value(MatrixA), m_Value(), m_Value(), +      } else if (match(V, m_Intrinsic<Intrinsic::matrix_column_major_store>( +                              m_Value(MatrixA), m_Value(), m_Value(), m_Value(),                                m_Value(M), m_Value(N)))) {          if (setShapeInfo(MatrixA, {M, N})) {            pushInstruction(MatrixA, WorkList);          }        } else if (isa<LoadInst>(V) || -                 match(V, m_Intrinsic<Intrinsic::matrix_columnwise_load>())) { +                 match(V, m_Intrinsic<Intrinsic::matrix_column_major_load>())) {          // Nothing to do, no matrix input.        } else if (isa<StoreInst>(V)) {          // Nothing to do.  We forward-propagated to this so we would just @@ -472,8 +670,8 @@ public:            switch (II->getIntrinsicID()) {            case Intrinsic::matrix_multiply:            case Intrinsic::matrix_transpose: -          case Intrinsic::matrix_columnwise_load: -          case Intrinsic::matrix_columnwise_store: +          case Intrinsic::matrix_column_major_load: +          case Intrinsic::matrix_column_major_store:              WorkList.push_back(&Inst);              break;            default: @@ -487,45 +685,57 @@ public:        }      } -    ReversePostOrderTraversal<Function *> RPOT(&Func);      bool Changed = false; -    for (auto *BB : RPOT) { -      for (Instruction &Inst : make_early_inc_range(*BB)) { -        IRBuilder<> Builder(&Inst); - -        if (CallInst *CInst = dyn_cast<CallInst>(&Inst)) -          Changed |= VisitCallInst(CInst); - -        Value *Op1; -        Value *Op2; -        if (auto *BinOp = dyn_cast<BinaryOperator>(&Inst)) -          Changed |= VisitBinaryOperator(BinOp); -        if (match(&Inst, m_Load(m_Value(Op1)))) -          Changed |= VisitLoad(&Inst, Op1, Builder); -        else if (match(&Inst, m_Store(m_Value(Op1), m_Value(Op2)))) -          Changed |= VisitStore(&Inst, Op1, Op2, Builder); +    SmallVector<CallInst *, 16> MaybeFusableInsts; +    SmallVector<Instruction *, 16> MatrixInsts; + +    // First, collect all instructions with shape information and candidates for +    // fusion (currently only matrix multiplies). +    ReversePostOrderTraversal<Function *> RPOT(&Func); +    for (auto *BB : RPOT) +      for (Instruction &I : *BB) { +        if (ShapeMap.find(&I) == ShapeMap.end()) +          continue; +        if (match(&I, m_Intrinsic<Intrinsic::matrix_multiply>())) +          MaybeFusableInsts.push_back(cast<CallInst>(&I)); +        MatrixInsts.push_back(&I);        } + +    // Second, try to fuse candidates. +    SmallPtrSet<Instruction *, 16> FusedInsts; +    for (CallInst *CI : MaybeFusableInsts) +      LowerMatrixMultiplyFused(CI, FusedInsts); +    Changed = !FusedInsts.empty(); + +    // Third, lower remaining instructions with shape information. +    for (Instruction *Inst : MatrixInsts) { +      if (FusedInsts.count(Inst)) +        continue; + +      IRBuilder<> Builder(Inst); + +      if (CallInst *CInst = dyn_cast<CallInst>(Inst)) +        Changed |= VisitCallInst(CInst); + +      Value *Op1; +      Value *Op2; +      if (auto *BinOp = dyn_cast<BinaryOperator>(Inst)) +        Changed |= VisitBinaryOperator(BinOp); +      if (match(Inst, m_Load(m_Value(Op1)))) +        Changed |= VisitLoad(cast<LoadInst>(Inst), Op1, Builder); +      else if (match(Inst, m_Store(m_Value(Op1), m_Value(Op2)))) +        Changed |= VisitStore(cast<StoreInst>(Inst), Op1, Op2, Builder);      } +    RemarkGenerator RemarkGen(Inst2ColumnMatrix, ORE, Func); +    RemarkGen.emitRemarks(); +      for (Instruction *Inst : reverse(ToRemove))        Inst->eraseFromParent();      return Changed;    } -  LoadInst *createColumnLoad(Value *ColumnPtr, Type *EltType, -                             IRBuilder<> Builder) { -    unsigned Align = DL.getABITypeAlignment(EltType); -    return Builder.CreateAlignedLoad(ColumnPtr, Align, "col.load"); -  } - -  StoreInst *createColumnStore(Value *ColumnValue, Value *ColumnPtr, -                               Type *EltType, IRBuilder<> Builder) { -    unsigned Align = DL.getABITypeAlignment(EltType); -    return Builder.CreateAlignedStore(ColumnValue, ColumnPtr, Align); -  } - -    /// Turns \p BasePtr into an elementwise pointer to \p EltType.    Value *createElementPtr(Value *BasePtr, Type *EltType, IRBuilder<> &Builder) {      unsigned AS = cast<PointerType>(BasePtr->getType())->getAddressSpace(); @@ -545,11 +755,11 @@ public:      case Intrinsic::matrix_transpose:        LowerTranspose(Inst);        break; -    case Intrinsic::matrix_columnwise_load: -      LowerColumnwiseLoad(Inst); +    case Intrinsic::matrix_column_major_load: +      LowerColumnMajorLoad(Inst);        break; -    case Intrinsic::matrix_columnwise_store: -      LowerColumnwiseStore(Inst); +    case Intrinsic::matrix_column_major_store: +      LowerColumnMajorStore(Inst);        break;      default:        return false; @@ -557,108 +767,200 @@ public:      return true;    } -  void LowerLoad(Instruction *Inst, Value *Ptr, Value *Stride, -                 ShapeInfo Shape) { -    IRBuilder<> Builder(Inst); -    auto VType = cast<VectorType>(Inst->getType()); +  /// Compute the alignment for a column/row \p Idx with \p Stride between them. +  /// The address at \p Idx == 0 has alignment \p A. If \p Stride is a +  /// ConstantInt, reduce the initial alignment based on the byte offset. For +  /// non-ConstantInt strides, return the common alignment of the initial +  /// alignment and the element size in bytes. +  Align getAlignForIndex(unsigned Idx, Value *Stride, Type *ElementTy, +                         MaybeAlign A) const { +    Align InitialAlign = DL.getValueOrABITypeAlignment(A, ElementTy); +    if (Idx == 0) +      return InitialAlign; + +    TypeSize ElementSizeInBits = DL.getTypeSizeInBits(ElementTy); +    if (auto *ConstStride = dyn_cast<ConstantInt>(Stride)) { +      uint64_t StrideInBytes = +          ConstStride->getZExtValue() * ElementSizeInBits / 8; +      return commonAlignment(InitialAlign, Idx * StrideInBytes); +    } +    return commonAlignment(InitialAlign, ElementSizeInBits / 8); +  } + +  /// Load a matrix with \p Shape starting at \p Ptr and using \p Stride between +  /// vectors. +  MatrixTy loadMatrix(Type *Ty, Value *Ptr, MaybeAlign MAlign, Value *Stride, +                      bool IsVolatile, ShapeInfo Shape, IRBuilder<> &Builder) { +    auto VType = cast<VectorType>(Ty);      Value *EltPtr = createElementPtr(Ptr, VType->getElementType(), Builder); -    ColumnMatrixTy Result; -    // Distance between start of one column and the start of the next -    for (unsigned C = 0, E = Shape.NumColumns; C < E; ++C) { -      Value *GEP = -          computeColumnAddr(EltPtr, Builder.getInt32(C), Stride, Shape.NumRows, -                            VType->getElementType(), Builder); -      Value *Column = createColumnLoad(GEP, VType->getElementType(), Builder); -      Result.addColumn(Column); +    MatrixTy Result; +    for (unsigned I = 0, E = Shape.getNumVectors(); I < E; ++I) { +      Value *GEP = computeVectorAddr(EltPtr, Builder.getInt64(I), Stride, +                                     Shape.getStride(), VType->getElementType(), +                                     Builder); +      Value *Vector = Builder.CreateAlignedLoad( +          GEP, getAlignForIndex(I, Stride, VType->getElementType(), MAlign), +          IsVolatile, "col.load"); + +      Result.addVector(Vector);      } +    return Result.addNumLoads(getNumOps(Result.getVectorTy()) * +                              Result.getNumVectors()); +  } -    finalizeLowering(Inst, Result, Builder); +  /// Loads a sub-matrix with shape \p ResultShape from a \p R x \p C matrix, +  /// starting at \p MatrixPtr[I][J]. +  MatrixTy loadMatrix(Value *MatrixPtr, MaybeAlign Align, bool IsVolatile, +                      ShapeInfo MatrixShape, Value *I, Value *J, +                      ShapeInfo ResultShape, Type *EltTy, +                      IRBuilder<> &Builder) { + +    Value *Offset = Builder.CreateAdd( +        Builder.CreateMul(J, Builder.getInt64(MatrixShape.getStride())), I); + +    unsigned AS = cast<PointerType>(MatrixPtr->getType())->getAddressSpace(); +    Value *EltPtr = +        Builder.CreatePointerCast(MatrixPtr, PointerType::get(EltTy, AS)); +    Value *TileStart = Builder.CreateGEP(EltTy, EltPtr, Offset); +    auto *TileTy = FixedVectorType::get(EltTy, ResultShape.NumRows * +                                                   ResultShape.NumColumns); +    Type *TilePtrTy = PointerType::get(TileTy, AS); +    Value *TilePtr = +        Builder.CreatePointerCast(TileStart, TilePtrTy, "col.cast"); + +    return loadMatrix(TileTy, TilePtr, Align, +                      Builder.getInt64(MatrixShape.getStride()), IsVolatile, +                      ResultShape, Builder); +  } + +  /// Lower a load instruction with shape information. +  void LowerLoad(Instruction *Inst, Value *Ptr, MaybeAlign Align, Value *Stride, +                 bool IsVolatile, ShapeInfo Shape) { +    IRBuilder<> Builder(Inst); +    finalizeLowering(Inst, +                     loadMatrix(Inst->getType(), Ptr, Align, Stride, IsVolatile, +                                Shape, Builder), +                     Builder);    } -  /// Lowers llvm.matrix.columnwise.load. +  /// Lowers llvm.matrix.column.major.load.    ///    /// The intrinsic loads a matrix from memory using a stride between columns. -  void LowerColumnwiseLoad(CallInst *Inst) { +  void LowerColumnMajorLoad(CallInst *Inst) { +    assert(MatrixLayout == MatrixLayoutTy::ColumnMajor && +           "Intrinsic only supports column-major layout!");      Value *Ptr = Inst->getArgOperand(0);      Value *Stride = Inst->getArgOperand(1); -    LowerLoad(Inst, Ptr, Stride, -              {Inst->getArgOperand(2), Inst->getArgOperand(3)}); +    LowerLoad(Inst, Ptr, Inst->getParamAlign(0), Stride, +              cast<ConstantInt>(Inst->getArgOperand(2))->isOne(), +              {Inst->getArgOperand(3), Inst->getArgOperand(4)});    } -  void LowerStore(Instruction *Inst, Value *Matrix, Value *Ptr, Value *Stride, -                  ShapeInfo Shape) { -    IRBuilder<> Builder(Inst); -    auto VType = cast<VectorType>(Matrix->getType()); +  /// Stores a sub-matrix \p StoreVal into the \p R x \p C matrix starting at \p +  /// MatrixPtr[I][J]. +  void storeMatrix(const MatrixTy &StoreVal, Value *MatrixPtr, +                   MaybeAlign MAlign, bool IsVolatile, ShapeInfo MatrixShape, +                   Value *I, Value *J, Type *EltTy, IRBuilder<> &Builder) { +    Value *Offset = Builder.CreateAdd( +        Builder.CreateMul(J, Builder.getInt64(MatrixShape.getStride())), I); + +    unsigned AS = cast<PointerType>(MatrixPtr->getType())->getAddressSpace(); +    Value *EltPtr = +        Builder.CreatePointerCast(MatrixPtr, PointerType::get(EltTy, AS)); +    Value *TileStart = Builder.CreateGEP(EltTy, EltPtr, Offset); +    auto *TileTy = FixedVectorType::get(EltTy, StoreVal.getNumRows() * +                                                   StoreVal.getNumColumns()); +    Type *TilePtrTy = PointerType::get(TileTy, AS); +    Value *TilePtr = +        Builder.CreatePointerCast(TileStart, TilePtrTy, "col.cast"); + +    storeMatrix(TileTy, StoreVal, TilePtr, MAlign, +                Builder.getInt64(MatrixShape.getStride()), IsVolatile, Builder); +  } + +  /// Store matrix \p StoreVal starting at \p Ptr and using \p Stride between +  /// vectors. +  MatrixTy storeMatrix(Type *Ty, MatrixTy StoreVal, Value *Ptr, +                       MaybeAlign MAlign, Value *Stride, bool IsVolatile, +                       IRBuilder<> &Builder) { +    auto VType = cast<VectorType>(Ty);      Value *EltPtr = createElementPtr(Ptr, VType->getElementType(), Builder); -    auto LM = getMatrix(Matrix, Shape, Builder); -    for (auto C : enumerate(LM.columns())) { -      Value *GEP = -          computeColumnAddr(EltPtr, Builder.getInt32(C.index()), Stride, -                            Shape.NumRows, VType->getElementType(), Builder); -      createColumnStore(C.value(), GEP, VType->getElementType(), Builder); +    for (auto Vec : enumerate(StoreVal.vectors())) { +      Value *GEP = computeVectorAddr(EltPtr, Builder.getInt64(Vec.index()), +                                     Stride, StoreVal.getStride(), +                                     VType->getElementType(), Builder); +      Builder.CreateAlignedStore(Vec.value(), GEP, +                                 getAlignForIndex(Vec.index(), Stride, +                                                  VType->getElementType(), +                                                  MAlign), +                                 IsVolatile);      } +    return MatrixTy().addNumStores(getNumOps(StoreVal.getVectorTy()) * +                                   StoreVal.getNumVectors()); +  } -    ToRemove.push_back(Inst); +  /// Lower a store instruction with shape information. +  void LowerStore(Instruction *Inst, Value *Matrix, Value *Ptr, MaybeAlign A, +                  Value *Stride, bool IsVolatile, ShapeInfo Shape) { +    IRBuilder<> Builder(Inst); +    auto StoreVal = getMatrix(Matrix, Shape, Builder); +    finalizeLowering(Inst, +                     storeMatrix(Matrix->getType(), StoreVal, Ptr, A, Stride, +                                 IsVolatile, Builder), +                     Builder);    } -  /// Lowers llvm.matrix.columnwise.store. +  /// Lowers llvm.matrix.column.major.store.    ///    /// The intrinsic store a matrix back memory using a stride between columns. -  void LowerColumnwiseStore(CallInst *Inst) { +  void LowerColumnMajorStore(CallInst *Inst) { +    assert(MatrixLayout == MatrixLayoutTy::ColumnMajor && +           "Intrinsic only supports column-major layout!");      Value *Matrix = Inst->getArgOperand(0);      Value *Ptr = Inst->getArgOperand(1);      Value *Stride = Inst->getArgOperand(2); -    LowerStore(Inst, Matrix, Ptr, Stride, -               {Inst->getArgOperand(3), Inst->getArgOperand(4)}); -  } - -  /// Extract a column vector of \p NumElts starting at index (\p I, \p J) from -  /// the matrix \p LM represented as a vector of column vectors. -  Value *extractVector(const ColumnMatrixTy &LM, unsigned I, unsigned J, -                       unsigned NumElts, IRBuilder<> Builder) { -    Value *Col = LM.getColumn(J); -    Value *Undef = UndefValue::get(Col->getType()); -    Constant *Mask = createSequentialMask(Builder, I, NumElts, 0); -    return Builder.CreateShuffleVector(Col, Undef, Mask, "block"); +    LowerStore(Inst, Matrix, Ptr, Inst->getParamAlign(1), Stride, +               cast<ConstantInt>(Inst->getArgOperand(3))->isOne(), +               {Inst->getArgOperand(4), Inst->getArgOperand(5)});    }    // Set elements I..I+NumElts-1 to Block    Value *insertVector(Value *Col, unsigned I, Value *Block, -                      IRBuilder<> Builder) { +                      IRBuilder<> &Builder) {      // First, bring Block to the same size as Col      unsigned BlockNumElts = -        cast<VectorType>(Block->getType())->getNumElements(); -    unsigned NumElts = cast<VectorType>(Col->getType())->getNumElements(); +        cast<FixedVectorType>(Block->getType())->getNumElements(); +    unsigned NumElts = cast<FixedVectorType>(Col->getType())->getNumElements();      assert(NumElts >= BlockNumElts && "Too few elements for current block"); -    Value *ExtendMask = -        createSequentialMask(Builder, 0, BlockNumElts, NumElts - BlockNumElts);      Value *Undef = UndefValue::get(Block->getType()); -    Block = Builder.CreateShuffleVector(Block, Undef, ExtendMask); +    Block = Builder.CreateShuffleVector( +        Block, Undef, +        createSequentialMask(0, BlockNumElts, NumElts - BlockNumElts));      // If Col is 7 long and I is 2 and BlockNumElts is 2 the mask is: 0, 1, 7,      // 8, 4, 5, 6 -    SmallVector<Constant *, 16> Mask; +    SmallVector<int, 16> Mask;      unsigned i;      for (i = 0; i < I; i++) -      Mask.push_back(Builder.getInt32(i)); +      Mask.push_back(i); -    unsigned VecNumElts = cast<VectorType>(Col->getType())->getNumElements(); +    unsigned VecNumElts = +        cast<FixedVectorType>(Col->getType())->getNumElements();      for (; i < I + BlockNumElts; i++) -      Mask.push_back(Builder.getInt32(i - I + VecNumElts)); +      Mask.push_back(i - I + VecNumElts);      for (; i < VecNumElts; i++) -      Mask.push_back(Builder.getInt32(i)); - -    Value *MaskVal = ConstantVector::get(Mask); +      Mask.push_back(i); -    return Builder.CreateShuffleVector(Col, Block, MaskVal); +    return Builder.CreateShuffleVector(Col, Block, Mask);    }    Value *createMulAdd(Value *Sum, Value *A, Value *B, bool UseFPOp, -                      IRBuilder<> &Builder, bool AllowContraction) { - +                      IRBuilder<> &Builder, bool AllowContraction, +                      unsigned &NumComputeOps) { +    NumComputeOps += getNumOps(A->getType());      if (!Sum)        return UseFPOp ? Builder.CreateFMul(A, B) : Builder.CreateMul(A, B); @@ -666,14 +968,16 @@ public:        if (AllowContraction) {          // Use fmuladd for floating point operations and let the backend decide          // if that's profitable. -        Value *FMulAdd = Intrinsic::getDeclaration( +        Function *FMulAdd = Intrinsic::getDeclaration(              Func.getParent(), Intrinsic::fmuladd, A->getType());          return Builder.CreateCall(FMulAdd, {A, B, Sum});        } +      NumComputeOps += getNumOps(A->getType());        Value *Mul = Builder.CreateFMul(A, B);        return Builder.CreateFAdd(Sum, Mul);      } +    NumComputeOps += getNumOps(A->getType());      Value *Mul = Builder.CreateMul(A, B);      return Builder.CreateAdd(Sum, Mul);    } @@ -683,7 +987,7 @@ public:    /// cached value when they are lowered. For other users, \p Matrix is    /// flattened and the uses are updated to use it. Also marks \p Inst for    /// deletion. -  void finalizeLowering(Instruction *Inst, ColumnMatrixTy Matrix, +  void finalizeLowering(Instruction *Inst, MatrixTy Matrix,                          IRBuilder<> &Builder) {      Inst2ColumnMatrix.insert(std::make_pair(Inst, Matrix)); @@ -699,6 +1003,294 @@ public:      }    } +  /// Compute \p Result += \p A * \p B for input matrices with left-associating +  /// addition. +  void emitMatrixMultiply(MatrixTy &Result, const MatrixTy &A, +                          const MatrixTy &B, bool AllowContraction, +                          IRBuilder<> &Builder, bool isTiled) { +    const unsigned VF = std::max<unsigned>( +        TTI.getRegisterBitWidth(true) / +            Result.getElementType()->getPrimitiveSizeInBits().getFixedSize(), +        1U); +    unsigned R = Result.getNumRows(); +    unsigned C = Result.getNumColumns(); +    unsigned M = A.getNumColumns(); + +    bool IsFP = Result.getElementType()->isFloatingPointTy(); +    assert(A.isColumnMajor() == B.isColumnMajor() && +           Result.isColumnMajor() == A.isColumnMajor() && +           "operands must agree on matrix layout"); +    unsigned NumComputeOps = 0; +    if (A.isColumnMajor()) { +      // Multiply columns from the first operand with scalars from the second +      // operand. Then move along the K axes and accumulate the columns.  With +      // this the adds can be vectorized without reassociation. +      for (unsigned J = 0; J < C; ++J) { +        unsigned BlockSize = VF; +        // If Result is zero, we don't need to accumulate in the K==0 iteration. +        bool isSumZero = isa<ConstantAggregateZero>(Result.getColumn(J)); + +        for (unsigned I = 0; I < R; I += BlockSize) { +          // Gradually lower the vectorization factor to cover the remainder. +          while (I + BlockSize > R) +            BlockSize /= 2; + +          Value *Sum = isTiled ? Result.extractVector(I, J, BlockSize, Builder) +                               : nullptr; +          for (unsigned K = 0; K < M; ++K) { +            Value *L = A.extractVector(I, K, BlockSize, Builder); +            Value *RH = Builder.CreateExtractElement(B.getColumn(J), K); +            Value *Splat = Builder.CreateVectorSplat(BlockSize, RH, "splat"); +            Sum = createMulAdd(isSumZero && K == 0 ? nullptr : Sum, L, Splat, +                               Result.getElementType()->isFloatingPointTy(), +                               Builder, AllowContraction, NumComputeOps); +          } +          Result.setVector(J, +                           insertVector(Result.getVector(J), I, Sum, Builder)); +        } +      } +    } else { +      // Multiply rows from the second operand with scalars from the first +      // operand. Then move along the K axes and accumulate the rows.  With this +      // the adds can be vectorized without reassociation. +      for (unsigned I = 0; I < R; ++I) { +        unsigned BlockSize = VF; +        bool isSumZero = isa<ConstantAggregateZero>(Result.getRow(I)); +        for (unsigned J = 0; J < C; J += BlockSize) { +          // Gradually lower the vectorization factor to cover the remainder. +          while (J + BlockSize > C) +            BlockSize /= 2; + +          Value *Sum = nullptr; +          for (unsigned K = 0; K < M; ++K) { +            Value *R = B.extractVector(K, J, BlockSize, Builder); +            Value *LH = Builder.CreateExtractElement(A.getVector(I), K); +            Value *Splat = Builder.CreateVectorSplat(BlockSize, LH, "splat"); +            Sum = createMulAdd(isSumZero && K == 0 ? nullptr : Sum, Splat, R, +                               IsFP, Builder, AllowContraction, NumComputeOps); +          } +          Result.setVector(I, +                           insertVector(Result.getVector(I), J, Sum, Builder)); +        } +      } +    } +    Result.addNumComputeOps(NumComputeOps); +  } + +  /// Ensure that the memory in \p Load does not alias \p Store by potentially +  /// copying it to a new location.  This new or otherwise the original location +  /// is returned. +  Value *getNonAliasingPointer(LoadInst *Load, StoreInst *Store, +                               CallInst *MatMul) { +    MemoryLocation StoreLoc = MemoryLocation::get(Store); +    MemoryLocation LoadLoc = MemoryLocation::get(Load); + +    AliasResult LdAliased = AA.alias(LoadLoc, StoreLoc); + +    // If we can statically determine noalias we're good. +    if (!LdAliased) +      return Load->getPointerOperand(); + +    // Create code to check if the memory locations of the Load and Store +    // overlap and if they do, copy Load's operand to a new buffer. + +    // First, create  new blocks for 2n part of the check and the copy. +    BasicBlock *Check0 = MatMul->getParent(); +    // FIXME: Use lazy DTU and update SplitBlock to accept a DTU instead of a +    // DT. Manually collect dominator tree updates, to avoid unnecessary work, +    // as we adjust Check0 and Check1's branches. +    SmallVector<DominatorTree::UpdateType, 4> DTUpdates; +    for (BasicBlock *Succ : successors(Check0)) +      DTUpdates.push_back({DT.Delete, Check0, Succ}); + +    BasicBlock *Check1 = SplitBlock(MatMul->getParent(), MatMul, nullptr, &LI, +                                    nullptr, "alias_cont"); +    BasicBlock *Copy = +        SplitBlock(MatMul->getParent(), MatMul, nullptr, &LI, nullptr, "copy"); +    BasicBlock *Fusion = SplitBlock(MatMul->getParent(), MatMul, nullptr, &LI, +                                    nullptr, "no_alias"); + +    // Check if the loaded memory location begins before the end of the store +    // location. If the condition holds, they might overlap, otherwise they are +    // guaranteed to not overlap. +    IRBuilder<> Builder(MatMul); +    Check0->getTerminator()->eraseFromParent(); +    Builder.SetInsertPoint(Check0); +    Type *IntPtrTy = Builder.getIntPtrTy(Load->getModule()->getDataLayout()); +    Value *StoreBegin = Builder.CreatePtrToInt( +        const_cast<Value *>(StoreLoc.Ptr), IntPtrTy, "store.begin"); +    Value *StoreEnd = Builder.CreateAdd( +        StoreBegin, ConstantInt::get(IntPtrTy, StoreLoc.Size.getValue()), +        "store.end", true, true); +    Value *LoadBegin = Builder.CreatePtrToInt(const_cast<Value *>(LoadLoc.Ptr), +                                              IntPtrTy, "load.begin"); +    Builder.CreateCondBr(Builder.CreateICmpULT(LoadBegin, StoreEnd), Check1, +                         Fusion); + +    // Check if the store begins before the end of the load location. If the +    // condition holds, they alias, otherwise they are guaranteed to not +    // overlap. +    Check1->getTerminator()->eraseFromParent(); +    Builder.SetInsertPoint(Check1, Check1->begin()); +    Value *LoadEnd = Builder.CreateAdd( +        LoadBegin, ConstantInt::get(IntPtrTy, LoadLoc.Size.getValue()), +        "load.end", true, true); +    Builder.CreateCondBr(Builder.CreateICmpULT(StoreBegin, LoadEnd), Copy, +                         Fusion); + +    // Copy load operand to new alloca. +    Builder.SetInsertPoint(Copy, Copy->begin()); +    AllocaInst *NewLd = +        Builder.CreateAlloca(Load->getType(), Load->getPointerAddressSpace()); +    Builder.CreateMemCpy(NewLd, NewLd->getAlign(), +                         Load->getPointerOperand(), Load->getAlign(), +                         LoadLoc.Size.getValue()); +    Builder.SetInsertPoint(Fusion, Fusion->begin()); +    PHINode *PHI = Builder.CreatePHI(Load->getPointerOperandType(), 3); +    PHI->addIncoming(Load->getPointerOperand(), Check0); +    PHI->addIncoming(Load->getPointerOperand(), Check1); +    PHI->addIncoming(NewLd, Copy); + +    // Adjust DT. +    DTUpdates.push_back({DT.Insert, Check0, Check1}); +    DTUpdates.push_back({DT.Insert, Check0, Fusion}); +    DTUpdates.push_back({DT.Insert, Check1, Copy}); +    DTUpdates.push_back({DT.Insert, Check1, Fusion}); +    DT.applyUpdates(DTUpdates); +    return PHI; +  } + +  bool isFusionProfitable(CallInst *MatMul) { +    if (ForceFusion) +      return true; + +    ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3)); +    ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4)); + +    const unsigned R = LShape.NumRows; +    const unsigned C = RShape.NumColumns; +    const unsigned M = LShape.NumColumns; +    auto *EltType = cast<VectorType>(MatMul->getType())->getElementType(); + +    const unsigned VF = +        std::max<unsigned>(TTI.getRegisterBitWidth(true) / +                               EltType->getPrimitiveSizeInBits().getFixedSize(), +                           1U); + +    // Cost model for tiling +    // +    // For tiling to be beneficial, we need reuse either along the R or +    // the C axis.  We vectorize along the R axis so that means at least +    // 3 elements. +    // TODO: Also consider cost of copying if operands alias. +    if (R <= VF && C == 1) +      return false; +    // Then we need enough elements to exceed the number of vector +    // registers we have.  Note that this is an oversimplification since +    // fusing also takes some extra loads which may exceed the number of +    // reloads necessary. +    unsigned Op0Regs = (R + VF - 1) / VF * M; +    unsigned Op1Regs = (M + VF - 1) / VF * C; +    return Op0Regs + Op1Regs > TTI.getNumberOfRegisters(true); +  } + +  MatrixTy getZeroMatrix(Type *EltType, unsigned R, unsigned C) { +    MatrixTy Res; +    auto *ColumType = FixedVectorType::get(EltType, R); +    for (unsigned I = 0; I < C; ++I) +      Res.addVector(ConstantAggregateZero::get(ColumType)); +    return Res; +  } + +  void emitSIMDTiling(CallInst *MatMul, LoadInst *LoadOp0, LoadInst *LoadOp1, +                      StoreInst *Store, +                      SmallPtrSetImpl<Instruction *> &FusedInsts) { +    assert(MatrixLayout == MatrixLayoutTy::ColumnMajor && +           "Tiling only supported for column-major matrixes at the moment!"); +    if (!isFusionProfitable(MatMul)) +      return; + +    ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3)); +    ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4)); + +    const unsigned R = LShape.NumRows; +    const unsigned C = RShape.NumColumns; +    const unsigned M = LShape.NumColumns; +    auto *EltType = cast<VectorType>(MatMul->getType())->getElementType(); + +    Value *APtr = getNonAliasingPointer(LoadOp0, Store, MatMul); +    Value *BPtr = getNonAliasingPointer(LoadOp1, Store, MatMul); +    Value *CPtr = Store->getPointerOperand(); + +    bool AllowContract = AllowContractEnabled || (isa<FPMathOperator>(MatMul) && +                                                  MatMul->hasAllowContract()); +    IRBuilder<> Builder(Store); +    for (unsigned J = 0; J < C; J += TileSize) +      for (unsigned I = 0; I < R; I += TileSize) { +        const unsigned TileR = std::min(R - I, unsigned(TileSize)); +        const unsigned TileC = std::min(C - J, unsigned(TileSize)); +        MatrixTy Res = getZeroMatrix(EltType, TileR, TileC); + +        for (unsigned K = 0; K < M; K += TileSize) { +          const unsigned TileM = std::min(M - K, unsigned(TileSize)); +          MatrixTy A = +              loadMatrix(APtr, LoadOp0->getAlign(), LoadOp0->isVolatile(), +                         LShape, Builder.getInt64(I), Builder.getInt64(K), +                         {TileR, TileM}, EltType, Builder); +          MatrixTy B = +              loadMatrix(BPtr, LoadOp1->getAlign(), LoadOp1->isVolatile(), +                         RShape, Builder.getInt64(K), Builder.getInt64(J), +                         {TileM, TileC}, EltType, Builder); +          emitMatrixMultiply(Res, A, B, AllowContract, Builder, true); +        } +        storeMatrix(Res, CPtr, Store->getAlign(), Store->isVolatile(), {R, M}, +                    Builder.getInt64(I), Builder.getInt64(J), EltType, Builder); +      } + +    // Mark eliminated instructions as fused and remove them. +    FusedInsts.insert(Store); +    FusedInsts.insert(MatMul); +    Store->eraseFromParent(); +    MatMul->eraseFromParent(); +    if (LoadOp0->hasNUses(0)) { +      FusedInsts.insert(LoadOp0); +      LoadOp0->eraseFromParent(); +    } +    if (LoadOp1->hasNUses(0)) { +      FusedInsts.insert(LoadOp1); +      LoadOp1->eraseFromParent(); +    } +  } + +  /// Try to lower matrix multiply chains by fusing operations. +  /// +  /// Currently we only lower {ld, ld} -> matmul -> st chains. +  // +  /// No need to return a MatrixTy object for the result of the operation, since +  /// the single store user will be lowered as part of this. Instructions that +  /// are completely eliminated by fusion are added to \p FusedInsts. +  void LowerMatrixMultiplyFused(CallInst *MatMul, +                                SmallPtrSetImpl<Instruction *> &FusedInsts) { +    if (!FuseMatrix || !MatMul->hasOneUse() || +        MatrixLayout != MatrixLayoutTy::ColumnMajor) +      return; + +    auto *LoadOp0 = dyn_cast<LoadInst>(MatMul->getOperand(0)); +    auto *LoadOp1 = dyn_cast<LoadInst>(MatMul->getOperand(1)); +    auto *Store = dyn_cast<StoreInst>(*MatMul->user_begin()); +    if (LoadOp0 && LoadOp1 && Store) { +      // The store address must dominate the MatMul instruction, otherwise +      // we create invalid IR. +      // FIXME: See if we can hoist the store address computation. +      auto *AddrI = dyn_cast<Instruction>(Store->getOperand(1)); +      if (AddrI && (!DT.dominates(AddrI, MatMul))) +        return; + +      emitSIMDTiling(MatMul, LoadOp0, LoadOp1, Store, FusedInsts); +      return; +    } +  } +    /// Lowers llvm.matrix.multiply.    void LowerMultiply(CallInst *MatMul) {      IRBuilder<> Builder(MatMul); @@ -706,97 +1298,80 @@ public:      ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));      ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4)); -    const ColumnMatrixTy &Lhs = -        getMatrix(MatMul->getArgOperand(0), LShape, Builder); -    const ColumnMatrixTy &Rhs = -        getMatrix(MatMul->getArgOperand(1), RShape, Builder); +    const MatrixTy &Lhs = getMatrix(MatMul->getArgOperand(0), LShape, Builder); +    const MatrixTy &Rhs = getMatrix(MatMul->getArgOperand(1), RShape, Builder);      const unsigned R = LShape.NumRows; -    const unsigned M = LShape.NumColumns;      const unsigned C = RShape.NumColumns; -    assert(M == RShape.NumRows); +    assert(LShape.NumColumns == RShape.NumRows);      // Initialize the output -    ColumnMatrixTy Result; -    for (unsigned J = 0; J < C; ++J) -      Result.addColumn(UndefValue::get(VectorType::get(EltType, R))); - -    const unsigned VF = std::max(TTI.getRegisterBitWidth(true) / -                                     EltType->getPrimitiveSizeInBits(), -                                 uint64_t(1)); +    MatrixTy Result(R, C, EltType);      bool AllowContract = AllowContractEnabled || (isa<FPMathOperator>(MatMul) &&                                                    MatMul->hasAllowContract()); -    // Multiply columns from the first operand with scalars from the second -    // operand.  Then move along the K axes and accumulate the columns.  With -    // this the adds can be vectorized without reassociation. -    for (unsigned J = 0; J < C; ++J) { -      unsigned BlockSize = VF; -      for (unsigned I = 0; I < R; I += BlockSize) { -        // Gradually lower the vectorization factor to cover the remainder. -        while (I + BlockSize > R) -          BlockSize /= 2; - -        Value *Sum = nullptr; -        for (unsigned K = 0; K < M; ++K) { -          Value *L = extractVector(Lhs, I, K, BlockSize, Builder); -          Value *RH = Builder.CreateExtractElement(Rhs.getColumn(J), K); -          Value *Splat = Builder.CreateVectorSplat(BlockSize, RH, "splat"); -          Sum = createMulAdd(Sum, L, Splat, EltType->isFloatingPointTy(), -                             Builder, AllowContract); -        } -        Result.setColumn(J, insertVector(Result.getColumn(J), I, Sum, Builder)); -      } -    } +    emitMatrixMultiply(Result, Lhs, Rhs, AllowContract, Builder, false);      finalizeLowering(MatMul, Result, Builder);    }    /// Lowers llvm.matrix.transpose.    void LowerTranspose(CallInst *Inst) { -    ColumnMatrixTy Result; +    MatrixTy Result;      IRBuilder<> Builder(Inst);      Value *InputVal = Inst->getArgOperand(0);      VectorType *VectorTy = cast<VectorType>(InputVal->getType());      ShapeInfo ArgShape(Inst->getArgOperand(1), Inst->getArgOperand(2)); -    ColumnMatrixTy InputMatrix = getMatrix(InputVal, ArgShape, Builder); - -    for (unsigned Row = 0; Row < ArgShape.NumRows; ++Row) { -      // Build a single column vector for this row. First initialize it. -      Value *ResultColumn = UndefValue::get( -          VectorType::get(VectorTy->getElementType(), ArgShape.NumColumns)); - -      // Go through the elements of this row and insert it into the resulting -      // column vector. -      for (auto C : enumerate(InputMatrix.columns())) { -        Value *Elt = Builder.CreateExtractElement(C.value(), Row); -        // We insert at index Column since that is the row index after the -        // transpose. -        ResultColumn = -            Builder.CreateInsertElement(ResultColumn, Elt, C.index()); +    MatrixTy InputMatrix = getMatrix(InputVal, ArgShape, Builder); + +    const unsigned NewNumVecs = +        InputMatrix.isColumnMajor() ? ArgShape.NumRows : ArgShape.NumColumns; +    const unsigned NewNumElts = +        InputMatrix.isColumnMajor() ? ArgShape.NumColumns : ArgShape.NumRows; + +    for (unsigned I = 0; I < NewNumVecs; ++I) { +      // Build a single result vector. First initialize it. +      Value *ResultVector = UndefValue::get( +          FixedVectorType::get(VectorTy->getElementType(), NewNumElts)); +      // Go through the old elements and insert it into the resulting vector. +      for (auto J : enumerate(InputMatrix.vectors())) { +        Value *Elt = Builder.CreateExtractElement(J.value(), I); +        // Row and column indices are transposed. +        ResultVector = +            Builder.CreateInsertElement(ResultVector, Elt, J.index());        } -      Result.addColumn(ResultColumn); +      Result.addVector(ResultVector);      } -    finalizeLowering(Inst, Result, Builder); +    // TODO: Improve estimate of operations needed for transposes. Currently we +    // just count the insertelement/extractelement instructions, but do not +    // account for later simplifications/combines. +    finalizeLowering( +        Inst, +        Result.addNumComputeOps(2 * ArgShape.NumRows * ArgShape.NumColumns), +        Builder);    }    /// Lower load instructions, if shape information is available. -  bool VisitLoad(Instruction *Inst, Value *Ptr, IRBuilder<> &Builder) { +  bool VisitLoad(LoadInst *Inst, Value *Ptr, IRBuilder<> &Builder) {      auto I = ShapeMap.find(Inst);      if (I == ShapeMap.end())        return false; -    LowerLoad(Inst, Ptr, Builder.getInt32(I->second.NumRows), I->second); +    LowerLoad(Inst, Ptr, Inst->getAlign(), +              Builder.getInt64(I->second.getStride()), Inst->isVolatile(), +              I->second);      return true;    } -  bool VisitStore(Instruction *Inst, Value *StoredVal, Value *Ptr, +  bool VisitStore(StoreInst *Inst, Value *StoredVal, Value *Ptr,                    IRBuilder<> &Builder) {      auto I = ShapeMap.find(StoredVal);      if (I == ShapeMap.end())        return false; -    LowerStore(Inst, StoredVal, Ptr, Builder.getInt32(I->second.NumRows), I->second); +    LowerStore(Inst, StoredVal, Ptr, Inst->getAlign(), +               Builder.getInt64(I->second.getStride()), Inst->isVolatile(), +               I->second);      return true;    } @@ -812,12 +1387,15 @@ public:      IRBuilder<> Builder(Inst);      ShapeInfo &Shape = I->second; -    ColumnMatrixTy LoweredLhs = getMatrix(Lhs, Shape, Builder); -    ColumnMatrixTy LoweredRhs = getMatrix(Rhs, Shape, Builder); +    MatrixTy Result; +    MatrixTy A = getMatrix(Lhs, Shape, Builder); +    MatrixTy B = getMatrix(Rhs, Shape, Builder); +    assert(A.isColumnMajor() == B.isColumnMajor() && +           Result.isColumnMajor() == A.isColumnMajor() && +           "operands must agree on matrix layout"); -    // Add each column and store the result back into the opmapping -    ColumnMatrixTy Result; -    auto BuildColumnOp = [&Builder, Inst](Value *LHS, Value *RHS) { +    // Helper to perform binary op on vectors. +    auto BuildVectorOp = [&Builder, Inst](Value *LHS, Value *RHS) {        switch (Inst->getOpcode()) {        case Instruction::Add:          return Builder.CreateAdd(LHS, RHS); @@ -835,20 +1413,462 @@ public:          llvm_unreachable("Unsupported binary operator for matrix");        }      }; -    for (unsigned C = 0; C < Shape.NumColumns; ++C) -      Result.addColumn( -          BuildColumnOp(LoweredLhs.getColumn(C), LoweredRhs.getColumn(C))); -    finalizeLowering(Inst, Result, Builder); +    for (unsigned I = 0; I < Shape.getNumVectors(); ++I) +      Result.addVector(BuildVectorOp(A.getVector(I), B.getVector(I))); + +    finalizeLowering(Inst, +                     Result.addNumComputeOps(getNumOps(Result.getVectorTy()) * +                                             Result.getNumVectors()), +                     Builder);      return true;    } + +  /// Helper to linearize a matrix expression tree into a string. Currently +  /// matrix expressions are linarized by starting at an expression leaf and +  /// linearizing bottom up. +  struct ExprLinearizer { +    unsigned LengthToBreak = 100; +    std::string Str; +    raw_string_ostream Stream; +    unsigned LineLength = 0; +    const DataLayout &DL; + +    /// Mapping from instructions to matrixes. It is used to identify +    /// matrix instructions. +    const MapVector<Value *, MatrixTy> &Inst2Matrix; + +    /// Mapping from values to the leaves of all expressions that the value is +    /// part of. +    const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared; + +    /// Set of matrix expressions in the scope of a given DISubprogram. +    const SmallSetVector<Value *, 32> &ExprsInSubprogram; + +    /// Leaf node of the expression to linearize. +    Value *Leaf; + +    /// Used to keep track of sub-expressions that get reused while linearizing +    /// the expression. Re-used sub-expressions are marked as (reused). +    SmallPtrSet<Value *, 8> ReusedExprs; + +    ExprLinearizer(const DataLayout &DL, +                   const MapVector<Value *, MatrixTy> &Inst2Matrix, +                   const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared, +                   const SmallSetVector<Value *, 32> &ExprsInSubprogram, +                   Value *Leaf) +        : Str(), Stream(Str), DL(DL), Inst2Matrix(Inst2Matrix), Shared(Shared), +          ExprsInSubprogram(ExprsInSubprogram), Leaf(Leaf) {} + +    void indent(unsigned N) { +      LineLength += N; +      for (unsigned i = 0; i < N; i++) +        Stream << " "; +    } + +    void lineBreak() { +      Stream << "\n"; +      LineLength = 0; +    } + +    void maybeIndent(unsigned Indent) { +      if (LineLength >= LengthToBreak) +        lineBreak(); + +      if (LineLength == 0) +        indent(Indent); +    } + +    void write(StringRef S) { +      LineLength += S.size(); +      Stream << S; +    } + +    Value *getUnderlyingObjectThroughLoads(Value *V) { +      if (Value *Ptr = getPointerOperand(V)) +        return getUnderlyingObjectThroughLoads(Ptr); +      else if (V->getType()->isPointerTy()) +        return GetUnderlyingObject(V, DL); +      return V; +    } + +    /// Returns true if \p V is a matrix value in the given subprogram. +    bool isMatrix(Value *V) const { return ExprsInSubprogram.count(V); } + +    /// If \p V is a matrix value, print its shape as as NumRows x NumColumns to +    /// \p SS. +    void prettyPrintMatrixType(Value *V, raw_string_ostream &SS) { +      auto M = Inst2Matrix.find(V); +      if (M == Inst2Matrix.end()) +        SS << "unknown"; +      else { +        SS << M->second.getNumRows(); +        SS << "x"; +        SS << M->second.getNumColumns(); +      } +    } + +    /// Write the called function name. Handles calls to llvm.matrix.* +    /// specially: we write the name, followed by the dimensions of the input +    /// matrixes, followed by the scalar type name. +    void writeFnName(CallInst *CI) { +      if (!CI->getCalledFunction()) +        write("<no called fn>"); +      else { +        StringRef Name = CI->getCalledFunction()->getName(); +        if (!Name.startswith("llvm.matrix")) { +          write(Name); +          return; +        } +        IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI); +        write(StringRef(Intrinsic::getName(II->getIntrinsicID(), {})) +                  .drop_front(StringRef("llvm.matrix.").size())); +        write("."); +        std::string Tmp = ""; +        raw_string_ostream SS(Tmp); + +        switch (II->getIntrinsicID()) { +        case Intrinsic::matrix_multiply: +          prettyPrintMatrixType(II->getOperand(0), SS); +          SS << "."; +          prettyPrintMatrixType(II->getOperand(1), SS); +          SS << "." << *II->getType()->getScalarType(); +          break; +        case Intrinsic::matrix_transpose: +          prettyPrintMatrixType(II->getOperand(0), SS); +          SS << "." << *II->getType()->getScalarType(); +          break; +        case Intrinsic::matrix_column_major_load: +          prettyPrintMatrixType(II, SS); +          SS << "." << *II->getType()->getScalarType(); +          break; +        case Intrinsic::matrix_column_major_store: +          prettyPrintMatrixType(II->getOperand(0), SS); +          SS << "." << *II->getOperand(0)->getType()->getScalarType(); +          break; +        default: +          llvm_unreachable("Unhandled case"); +        } +        SS.flush(); +        write(Tmp); +      } +    } + +    unsigned getNumShapeArgs(CallInst *CI) const { +      if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI)) { +        switch (II->getIntrinsicID()) { +        case Intrinsic::matrix_multiply: +          return 3; +        case Intrinsic::matrix_transpose: +          return 2; +        case Intrinsic::matrix_column_major_load: +        case Intrinsic::matrix_column_major_store: +          return 3; +        default: +          return 0; +        } +      } +      return 0; +    } + +    /// Special printing for values: for pointers, we print if they refer to an +    /// (function) external address or a stack address, for other values we +    /// either print the constant or "scalar"/"matrix" for other values. +    void write(Value *V) { +      V = getUnderlyingObjectThroughLoads(V); +      if (V->getType()->isPointerTy()) { +        if (isa<AllocaInst>(V)) { +          Stream << "stack addr"; +          LineLength += StringRef("stack addr").size(); +        } else { +          Stream << "addr"; +          LineLength += StringRef("addr").size(); +        } +        if (!V->getName().empty()) { +          Stream << " %" << V->getName() << ""; +          LineLength += V->getName().size() + 2; +        } +        return; +      } + +      std::string Tmp; +      raw_string_ostream TmpStream(Tmp); + +      if (auto *CI = dyn_cast<ConstantInt>(V)) +        TmpStream << CI->getValue(); +      else if (isa<Constant>(V)) +        TmpStream << "constant"; +      else { +        if (isMatrix(V)) +          TmpStream << "matrix"; +        else +          TmpStream << "scalar"; +      } +      TmpStream.flush(); +      Tmp = std::string(StringRef(Tmp).trim()); +      LineLength += Tmp.size(); +      Stream << Tmp; +    } + +    /// Linearize expression \p Expr starting at an indentation of \p Indent. +    /// Expressions that are re-used multiple times are prefixed with (reused) +    /// at the re-used root instruction. +    void linearizeExpr(Value *Expr, unsigned Indent, bool ParentReused, +                       bool ParentShared) { +      auto *I = cast<Instruction>(Expr); +      maybeIndent(Indent); +      SmallVector<Value *, 8> Ops; + +      // Is Expr shared with other expression leaves? +      bool ExprShared = false; + +      // Deal with shared subtrees. Mark them as shared, if required. +      if (!ParentShared) { +        auto SI = Shared.find(Expr); +        assert(SI != Shared.end() && SI->second.count(Leaf)); + +        for (Value *S : SI->second) { +          if (S == Leaf) +            continue; +          DebugLoc DL = cast<Instruction>(S)->getDebugLoc(); +          write("shared with remark at line " + std::to_string(DL.getLine()) + +                " column " + std::to_string(DL.getCol()) + " ("); +        } +        ExprShared = SI->second.size() > 1; +      } + +      bool Reused = !ReusedExprs.insert(Expr).second; +      if (Reused && !ParentReused) +        write("(reused) "); + +      if (auto *CI = dyn_cast<CallInst>(I)) { +        writeFnName(CI); + +        Ops.append(CI->arg_begin(), CI->arg_end() - getNumShapeArgs(CI)); +      } else if (isa<BitCastInst>(Expr)) { +        // Special case bitcasts, which are used to materialize matrixes from +        // non-matrix ops. +        write("matrix"); +        return; +      } else { +        Ops.append(I->value_op_begin(), I->value_op_end()); +        write(std::string(I->getOpcodeName())); +      } + +      write(std::string("(")); + +      unsigned NumOpsToBreak = 1; +      if (match(Expr, m_Intrinsic<Intrinsic::matrix_column_major_load>())) +        NumOpsToBreak = 2; + +      for (Value *Op : Ops) { +        if (Ops.size() > NumOpsToBreak) +          lineBreak(); + +        maybeIndent(Indent + 1); +        if (isMatrix(Op)) +          linearizeExpr(Op, Indent + 1, Reused, ExprShared); +        else +          write(Op); +        if (Op != Ops.back()) +          write(", "); +      } + +      write(")"); +    } + +    const std::string &getResult() { +      Stream.flush(); +      return Str; +    } +  }; + +  /// Generate remarks for matrix operations in a function. To generate remarks +  /// for matrix expressions, the following approach is used: +  /// 1. Use the inlined-at debug information to group matrix operations to the +  ///    DISubprograms they are contained in. +  /// 2. Collect leaves of matrix expressions (done in +  ///    RemarkGenerator::getExpressionLeaves) for each subprogram - expression +  //     mapping.  Leaves are lowered matrix instructions without other matrix +  //     users (like stores) in the current subprogram. +  /// 3. For each leaf, create a remark containing a linearizied version of the +  ///    matrix expression. The expression is linearized by a recursive +  ///    bottom-up traversal of the matrix operands, starting at a leaf. Note +  ///    that multiple leaves can share sub-expressions. Shared subexpressions +  ///    are explicitly marked as shared(). +  struct RemarkGenerator { +    const MapVector<Value *, MatrixTy> &Inst2Matrix; +    OptimizationRemarkEmitter &ORE; +    Function &Func; +    const DataLayout &DL; + +    RemarkGenerator(const MapVector<Value *, MatrixTy> &Inst2Matrix, +                    OptimizationRemarkEmitter &ORE, Function &Func) +        : Inst2Matrix(Inst2Matrix), ORE(ORE), Func(Func), +          DL(Func.getParent()->getDataLayout()) {} + +    /// Return all leaves of the expressions in \p ExprsInSubprogram. Those are +    /// instructions in Inst2Matrix returning void or without any users in +    /// \p ExprsInSubprogram. Currently that should only include stores. +    SmallVector<Value *, 4> +    getExpressionLeaves(const SmallSetVector<Value *, 32> &ExprsInSubprogram) { +      SmallVector<Value *, 4> Leaves; +      for (auto *Expr : ExprsInSubprogram) +        if (Expr->getType()->isVoidTy() || +            !any_of(Expr->users(), [&ExprsInSubprogram](User *U) { +              return ExprsInSubprogram.count(U); +            })) +          Leaves.push_back(Expr); +      return Leaves; +    } + +    /// Recursively traverse expression \p V starting at \p Leaf and add \p Leaf +    /// to all visited expressions in \p Shared. Limit the matrix operations to +    /// the ones in \p ExprsInSubprogram. +    void collectSharedInfo(Value *Leaf, Value *V, +                           const SmallSetVector<Value *, 32> &ExprsInSubprogram, +                           DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared) { + +      if (!ExprsInSubprogram.count(V)) +        return; + +      auto I = Shared.insert({V, {}}); +      I.first->second.insert(Leaf); + +      for (Value *Op : cast<Instruction>(V)->operand_values()) +        collectSharedInfo(Leaf, Op, ExprsInSubprogram, Shared); +      return; +    } + +    /// Calculate the number of exclusive and shared op counts for expression +    /// starting at \p V. Expressions used multiple times are counted once. +    /// Limit the matrix operations to the ones in \p ExprsInSubprogram. +    std::pair<OpInfoTy, OpInfoTy> +    sumOpInfos(Value *Root, SmallPtrSetImpl<Value *> &ReusedExprs, +               const SmallSetVector<Value *, 32> &ExprsInSubprogram, +               DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared) const { +      if (!ExprsInSubprogram.count(Root)) +        return {}; + +      // Already counted this expression. Stop. +      if (!ReusedExprs.insert(Root).second) +        return {}; + +      OpInfoTy SharedCount; +      OpInfoTy Count; + +      auto I = Shared.find(Root); +      auto CM = Inst2Matrix.find(Root); +      if (I->second.size() == 1) +        Count = CM->second.getOpInfo(); +      else +        SharedCount = CM->second.getOpInfo(); + +      for (Value *Op : cast<Instruction>(Root)->operand_values()) { +        auto C = sumOpInfos(Op, ReusedExprs, ExprsInSubprogram, Shared); +        Count += C.first; +        SharedCount += C.second; +      } +      return {Count, SharedCount}; +    } + +    void emitRemarks() { +      if (!ORE.allowExtraAnalysis(DEBUG_TYPE)) +        return; + +      // Map matrix operations to their containting subprograms, by traversing +      // the inlinedAt chain. If the function does not have a DISubprogram, we +      // only map them to the containing function. +      MapVector<DISubprogram *, SmallVector<Value *, 8>> Subprog2Exprs; +      for (auto &KV : Inst2Matrix) { +        if (Func.getSubprogram()) { +          auto *I = cast<Instruction>(KV.first); +          DILocation *Context = I->getDebugLoc(); +          while (Context) { +            auto I = +                Subprog2Exprs.insert({getSubprogram(Context->getScope()), {}}); +            I.first->second.push_back(KV.first); +            Context = DebugLoc(Context).getInlinedAt(); +          } +        } else { +          auto I = Subprog2Exprs.insert({nullptr, {}}); +          I.first->second.push_back(KV.first); +        } +      } +      for (auto &KV : Subprog2Exprs) { +        SmallSetVector<Value *, 32> ExprsInSubprogram(KV.second.begin(), +                                                      KV.second.end()); +        auto Leaves = getExpressionLeaves(ExprsInSubprogram); + +        DenseMap<Value *, SmallPtrSet<Value *, 2>> Shared; +        for (Value *Leaf : Leaves) +          collectSharedInfo(Leaf, Leaf, ExprsInSubprogram, Shared); + +        // Generate remarks for each leaf. +        for (auto *L : Leaves) { + +          DebugLoc Loc = cast<Instruction>(L)->getDebugLoc(); +          DILocation *Context = cast<Instruction>(L)->getDebugLoc(); +          while (Context) { +            if (getSubprogram(Context->getScope()) == KV.first) { +              Loc = Context; +              break; +            } +            Context = DebugLoc(Context).getInlinedAt(); +          } + +          SmallPtrSet<Value *, 8> ReusedExprs; +          OpInfoTy Counts, SharedCounts; +          std::tie(Counts, SharedCounts) = +              sumOpInfos(L, ReusedExprs, ExprsInSubprogram, Shared); + +          OptimizationRemark Rem(DEBUG_TYPE, "matrix-lowered", Loc, +                                 cast<Instruction>(L)->getParent()); + +          Rem << "Lowered with "; +          Rem << ore::NV("NumStores", Counts.NumStores) << " stores, " +              << ore::NV("NumLoads", Counts.NumLoads) << " loads, " +              << ore::NV("NumComputeOps", Counts.NumComputeOps) +              << " compute ops"; + +          if (SharedCounts.NumStores > 0 || SharedCounts.NumLoads > 0 || +              SharedCounts.NumComputeOps > 0) { +            Rem << ",\nadditionally " +                << ore::NV("NumStores", SharedCounts.NumStores) << " stores, " +                << ore::NV("NumLoads", SharedCounts.NumLoads) << " loads, " +                << ore::NV("NumFPOps", SharedCounts.NumComputeOps) +                << " compute ops" +                << " are shared with other expressions"; +          } + +          Rem << ("\n" + linearize(L, Shared, ExprsInSubprogram, DL)); +          ORE.emit(Rem); +        } +      } +    } + +    std::string +    linearize(Value *L, +              const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared, +              const SmallSetVector<Value *, 32> &ExprsInSubprogram, +              const DataLayout &DL) { +      ExprLinearizer Lin(DL, Inst2Matrix, Shared, ExprsInSubprogram, L); +      Lin.linearizeExpr(L, 0, false, false); +      return Lin.getResult(); +    } +  };  };  } // namespace  PreservedAnalyses LowerMatrixIntrinsicsPass::run(Function &F,                                                   FunctionAnalysisManager &AM) {    auto &TTI = AM.getResult<TargetIRAnalysis>(F); -  LowerMatrixIntrinsics LMT(F, TTI); +  auto &ORE = AM.getResult<OptimizationRemarkEmitterAnalysis>(F); +  auto &AA = AM.getResult<AAManager>(F); +  auto &DT = AM.getResult<DominatorTreeAnalysis>(F); +  auto &LI = AM.getResult<LoopAnalysis>(F); + +  LowerMatrixIntrinsics LMT(F, TTI, AA, DT, LI, ORE);    if (LMT.Visit()) {      PreservedAnalyses PA;      PA.preserveSet<CFGAnalyses>(); @@ -869,15 +1889,24 @@ public:    }    bool runOnFunction(Function &F) override { -    auto *TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F); -    LowerMatrixIntrinsics LMT(F, *TTI); +    auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F); +    auto &ORE = getAnalysis<OptimizationRemarkEmitterWrapperPass>().getORE(); +    auto &AA = getAnalysis<AAResultsWrapperPass>().getAAResults(); +    auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree(); +    auto &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); +    LowerMatrixIntrinsics LMT(F, TTI, AA, DT, LI, ORE);      bool C = LMT.Visit();      return C;    }    void getAnalysisUsage(AnalysisUsage &AU) const override {      AU.addRequired<TargetTransformInfoWrapperPass>(); -    AU.setPreservesCFG(); +    AU.addRequired<OptimizationRemarkEmitterWrapperPass>(); +    AU.addRequired<AAResultsWrapperPass>(); +    AU.addRequired<DominatorTreeWrapperPass>(); +    AU.addPreserved<DominatorTreeWrapperPass>(); +    AU.addRequired<LoopInfoWrapperPass>(); +    AU.addPreserved<LoopInfoWrapperPass>();    }  };  } // namespace @@ -886,6 +1915,10 @@ static const char pass_name[] = "Lower the matrix intrinsics";  char LowerMatrixIntrinsicsLegacyPass::ID = 0;  INITIALIZE_PASS_BEGIN(LowerMatrixIntrinsicsLegacyPass, DEBUG_TYPE, pass_name,                        false, false) +INITIALIZE_PASS_DEPENDENCY(OptimizationRemarkEmitterWrapperPass) +INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass) +INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) +INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)  INITIALIZE_PASS_END(LowerMatrixIntrinsicsLegacyPass, DEBUG_TYPE, pass_name,                      false, false) | 
