diff options
Diffstat (limited to 'contrib/llvm-project/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp')
| -rw-r--r-- | contrib/llvm-project/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp | 46 |
1 files changed, 10 insertions, 36 deletions
diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp index f46ea6a20afa..72b9db1e73d7 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp @@ -19,6 +19,7 @@ #include "llvm/Transforms/Scalar/LowerMatrixIntrinsics.h" #include "llvm/ADT/PostOrderIterator.h" +#include "llvm/ADT/SmallSet.h" #include "llvm/ADT/SmallVector.h" #include "llvm/Analysis/AliasAnalysis.h" #include "llvm/Analysis/DomTreeUpdater.h" @@ -36,12 +37,9 @@ #include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/MatrixBuilder.h" #include "llvm/IR/PatternMatch.h" -#include "llvm/InitializePasses.h" -#include "llvm/Pass.h" #include "llvm/Support/Alignment.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" -#include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/LoopUtils.h" #include "llvm/Transforms/Utils/MatrixUtils.h" @@ -180,7 +178,6 @@ Value *computeVectorAddr(Value *BasePtr, Value *VecIdx, Value *Stride, assert((!isa<ConstantInt>(Stride) || cast<ConstantInt>(Stride)->getZExtValue() >= NumElements) && "Stride must be >= the number of elements in the result vector."); - unsigned AS = cast<PointerType>(BasePtr->getType())->getAddressSpace(); // Compute the start of the vector with index VecIdx as VecIdx * Stride. Value *VecStart = Builder.CreateMul(VecIdx, Stride, "vec.start"); @@ -192,11 +189,7 @@ Value *computeVectorAddr(Value *BasePtr, Value *VecIdx, Value *Stride, else VecStart = Builder.CreateGEP(EltType, BasePtr, VecStart, "vec.gep"); - // Cast elementwise vector start pointer to a pointer to a vector - // (EltType x NumElements)*. - auto *VecType = FixedVectorType::get(EltType, NumElements); - Type *VecPtrType = PointerType::get(VecType, AS); - return Builder.CreatePointerCast(VecStart, VecPtrType, "vec.cast"); + return VecStart; } /// LowerMatrixIntrinsics contains the methods used to lower matrix intrinsics. @@ -1063,13 +1056,6 @@ public: return Changed; } - /// Turns \p BasePtr into an elementwise pointer to \p EltType. - Value *createElementPtr(Value *BasePtr, Type *EltType, IRBuilder<> &Builder) { - unsigned AS = cast<PointerType>(BasePtr->getType())->getAddressSpace(); - Type *EltPtrType = PointerType::get(EltType, AS); - return Builder.CreatePointerCast(BasePtr, EltPtrType); - } - /// Replace intrinsic calls bool VisitCallInst(CallInst *Inst) { if (!Inst->getCalledFunction() || !Inst->getCalledFunction()->isIntrinsic()) @@ -1121,7 +1107,7 @@ public: auto *VType = cast<VectorType>(Ty); Type *EltTy = VType->getElementType(); Type *VecTy = FixedVectorType::get(EltTy, Shape.getStride()); - Value *EltPtr = createElementPtr(Ptr, EltTy, Builder); + Value *EltPtr = Ptr; MatrixTy Result; for (unsigned I = 0, E = Shape.getNumVectors(); I < E; ++I) { Value *GEP = computeVectorAddr( @@ -1147,17 +1133,11 @@ public: Value *Offset = Builder.CreateAdd( Builder.CreateMul(J, Builder.getInt64(MatrixShape.getStride())), I); - unsigned AS = cast<PointerType>(MatrixPtr->getType())->getAddressSpace(); - Value *EltPtr = - Builder.CreatePointerCast(MatrixPtr, PointerType::get(EltTy, AS)); - Value *TileStart = Builder.CreateGEP(EltTy, EltPtr, Offset); + Value *TileStart = Builder.CreateGEP(EltTy, MatrixPtr, Offset); auto *TileTy = FixedVectorType::get(EltTy, ResultShape.NumRows * ResultShape.NumColumns); - Type *TilePtrTy = PointerType::get(TileTy, AS); - Value *TilePtr = - Builder.CreatePointerCast(TileStart, TilePtrTy, "col.cast"); - return loadMatrix(TileTy, TilePtr, Align, + return loadMatrix(TileTy, TileStart, Align, Builder.getInt64(MatrixShape.getStride()), IsVolatile, ResultShape, Builder); } @@ -1193,17 +1173,11 @@ public: Value *Offset = Builder.CreateAdd( Builder.CreateMul(J, Builder.getInt64(MatrixShape.getStride())), I); - unsigned AS = cast<PointerType>(MatrixPtr->getType())->getAddressSpace(); - Value *EltPtr = - Builder.CreatePointerCast(MatrixPtr, PointerType::get(EltTy, AS)); - Value *TileStart = Builder.CreateGEP(EltTy, EltPtr, Offset); + Value *TileStart = Builder.CreateGEP(EltTy, MatrixPtr, Offset); auto *TileTy = FixedVectorType::get(EltTy, StoreVal.getNumRows() * StoreVal.getNumColumns()); - Type *TilePtrTy = PointerType::get(TileTy, AS); - Value *TilePtr = - Builder.CreatePointerCast(TileStart, TilePtrTy, "col.cast"); - storeMatrix(TileTy, StoreVal, TilePtr, MAlign, + storeMatrix(TileTy, StoreVal, TileStart, MAlign, Builder.getInt64(MatrixShape.getStride()), IsVolatile, Builder); } @@ -1213,7 +1187,7 @@ public: MaybeAlign MAlign, Value *Stride, bool IsVolatile, IRBuilder<> &Builder) { auto VType = cast<VectorType>(Ty); - Value *EltPtr = createElementPtr(Ptr, VType->getElementType(), Builder); + Value *EltPtr = Ptr; for (auto Vec : enumerate(StoreVal.vectors())) { Value *GEP = computeVectorAddr( EltPtr, @@ -2180,7 +2154,7 @@ public: /// Returns true if \p V is a matrix value in the given subprogram. bool isMatrix(Value *V) const { return ExprsInSubprogram.count(V); } - /// If \p V is a matrix value, print its shape as as NumRows x NumColumns to + /// If \p V is a matrix value, print its shape as NumRows x NumColumns to /// \p SS. void prettyPrintMatrixType(Value *V, raw_string_ostream &SS) { auto M = Inst2Matrix.find(V); @@ -2201,7 +2175,7 @@ public: write("<no called fn>"); else { StringRef Name = CI->getCalledFunction()->getName(); - if (!Name.startswith("llvm.matrix")) { + if (!Name.starts_with("llvm.matrix")) { write(Name); return; } |
