aboutsummaryrefslogtreecommitdiff
path: root/contrib/llvm-project/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'contrib/llvm-project/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp')
-rw-r--r--contrib/llvm-project/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp46
1 files changed, 10 insertions, 36 deletions
diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
index f46ea6a20afa..72b9db1e73d7 100644
--- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
+++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
@@ -19,6 +19,7 @@
#include "llvm/Transforms/Scalar/LowerMatrixIntrinsics.h"
#include "llvm/ADT/PostOrderIterator.h"
+#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Analysis/AliasAnalysis.h"
#include "llvm/Analysis/DomTreeUpdater.h"
@@ -36,12 +37,9 @@
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/MatrixBuilder.h"
#include "llvm/IR/PatternMatch.h"
-#include "llvm/InitializePasses.h"
-#include "llvm/Pass.h"
#include "llvm/Support/Alignment.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
-#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Transforms/Utils/LoopUtils.h"
#include "llvm/Transforms/Utils/MatrixUtils.h"
@@ -180,7 +178,6 @@ Value *computeVectorAddr(Value *BasePtr, Value *VecIdx, Value *Stride,
assert((!isa<ConstantInt>(Stride) ||
cast<ConstantInt>(Stride)->getZExtValue() >= NumElements) &&
"Stride must be >= the number of elements in the result vector.");
- unsigned AS = cast<PointerType>(BasePtr->getType())->getAddressSpace();
// Compute the start of the vector with index VecIdx as VecIdx * Stride.
Value *VecStart = Builder.CreateMul(VecIdx, Stride, "vec.start");
@@ -192,11 +189,7 @@ Value *computeVectorAddr(Value *BasePtr, Value *VecIdx, Value *Stride,
else
VecStart = Builder.CreateGEP(EltType, BasePtr, VecStart, "vec.gep");
- // Cast elementwise vector start pointer to a pointer to a vector
- // (EltType x NumElements)*.
- auto *VecType = FixedVectorType::get(EltType, NumElements);
- Type *VecPtrType = PointerType::get(VecType, AS);
- return Builder.CreatePointerCast(VecStart, VecPtrType, "vec.cast");
+ return VecStart;
}
/// LowerMatrixIntrinsics contains the methods used to lower matrix intrinsics.
@@ -1063,13 +1056,6 @@ public:
return Changed;
}
- /// Turns \p BasePtr into an elementwise pointer to \p EltType.
- Value *createElementPtr(Value *BasePtr, Type *EltType, IRBuilder<> &Builder) {
- unsigned AS = cast<PointerType>(BasePtr->getType())->getAddressSpace();
- Type *EltPtrType = PointerType::get(EltType, AS);
- return Builder.CreatePointerCast(BasePtr, EltPtrType);
- }
-
/// Replace intrinsic calls
bool VisitCallInst(CallInst *Inst) {
if (!Inst->getCalledFunction() || !Inst->getCalledFunction()->isIntrinsic())
@@ -1121,7 +1107,7 @@ public:
auto *VType = cast<VectorType>(Ty);
Type *EltTy = VType->getElementType();
Type *VecTy = FixedVectorType::get(EltTy, Shape.getStride());
- Value *EltPtr = createElementPtr(Ptr, EltTy, Builder);
+ Value *EltPtr = Ptr;
MatrixTy Result;
for (unsigned I = 0, E = Shape.getNumVectors(); I < E; ++I) {
Value *GEP = computeVectorAddr(
@@ -1147,17 +1133,11 @@ public:
Value *Offset = Builder.CreateAdd(
Builder.CreateMul(J, Builder.getInt64(MatrixShape.getStride())), I);
- unsigned AS = cast<PointerType>(MatrixPtr->getType())->getAddressSpace();
- Value *EltPtr =
- Builder.CreatePointerCast(MatrixPtr, PointerType::get(EltTy, AS));
- Value *TileStart = Builder.CreateGEP(EltTy, EltPtr, Offset);
+ Value *TileStart = Builder.CreateGEP(EltTy, MatrixPtr, Offset);
auto *TileTy = FixedVectorType::get(EltTy, ResultShape.NumRows *
ResultShape.NumColumns);
- Type *TilePtrTy = PointerType::get(TileTy, AS);
- Value *TilePtr =
- Builder.CreatePointerCast(TileStart, TilePtrTy, "col.cast");
- return loadMatrix(TileTy, TilePtr, Align,
+ return loadMatrix(TileTy, TileStart, Align,
Builder.getInt64(MatrixShape.getStride()), IsVolatile,
ResultShape, Builder);
}
@@ -1193,17 +1173,11 @@ public:
Value *Offset = Builder.CreateAdd(
Builder.CreateMul(J, Builder.getInt64(MatrixShape.getStride())), I);
- unsigned AS = cast<PointerType>(MatrixPtr->getType())->getAddressSpace();
- Value *EltPtr =
- Builder.CreatePointerCast(MatrixPtr, PointerType::get(EltTy, AS));
- Value *TileStart = Builder.CreateGEP(EltTy, EltPtr, Offset);
+ Value *TileStart = Builder.CreateGEP(EltTy, MatrixPtr, Offset);
auto *TileTy = FixedVectorType::get(EltTy, StoreVal.getNumRows() *
StoreVal.getNumColumns());
- Type *TilePtrTy = PointerType::get(TileTy, AS);
- Value *TilePtr =
- Builder.CreatePointerCast(TileStart, TilePtrTy, "col.cast");
- storeMatrix(TileTy, StoreVal, TilePtr, MAlign,
+ storeMatrix(TileTy, StoreVal, TileStart, MAlign,
Builder.getInt64(MatrixShape.getStride()), IsVolatile, Builder);
}
@@ -1213,7 +1187,7 @@ public:
MaybeAlign MAlign, Value *Stride, bool IsVolatile,
IRBuilder<> &Builder) {
auto VType = cast<VectorType>(Ty);
- Value *EltPtr = createElementPtr(Ptr, VType->getElementType(), Builder);
+ Value *EltPtr = Ptr;
for (auto Vec : enumerate(StoreVal.vectors())) {
Value *GEP = computeVectorAddr(
EltPtr,
@@ -2180,7 +2154,7 @@ public:
/// Returns true if \p V is a matrix value in the given subprogram.
bool isMatrix(Value *V) const { return ExprsInSubprogram.count(V); }
- /// If \p V is a matrix value, print its shape as as NumRows x NumColumns to
+ /// If \p V is a matrix value, print its shape as NumRows x NumColumns to
/// \p SS.
void prettyPrintMatrixType(Value *V, raw_string_ostream &SS) {
auto M = Inst2Matrix.find(V);
@@ -2201,7 +2175,7 @@ public:
write("<no called fn>");
else {
StringRef Name = CI->getCalledFunction()->getName();
- if (!Name.startswith("llvm.matrix")) {
+ if (!Name.starts_with("llvm.matrix")) {
write(Name);
return;
}