diff options
Diffstat (limited to 'llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp')
| -rw-r--r-- | llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp | 26 |
1 files changed, 19 insertions, 7 deletions
diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp index 42c183a6408e..4e4097e13271 100644 --- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp +++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp @@ -900,8 +900,7 @@ public: // UndefedInsts and then check that we in fact remove them. SmallSet<Instruction *, 16> UndefedInsts; for (auto *Inst : reverse(ToRemove)) { - for (auto I = Inst->use_begin(), E = Inst->use_end(); I != E;) { - Use &U = *I++; + for (Use &U : llvm::make_early_inc_range(Inst->uses())) { if (auto *Undefed = dyn_cast<Instruction>(U.getUser())) UndefedInsts.insert(Undefed); U.set(UndefValue::get(Inst->getType())); @@ -981,8 +980,9 @@ public: Value *EltPtr = createElementPtr(Ptr, EltTy, Builder); MatrixTy Result; for (unsigned I = 0, E = Shape.getNumVectors(); I < E; ++I) { - Value *GEP = computeVectorAddr(EltPtr, Builder.getInt64(I), Stride, - Shape.getStride(), EltTy, Builder); + Value *GEP = computeVectorAddr( + EltPtr, Builder.getIntN(Stride->getType()->getScalarSizeInBits(), I), + Stride, Shape.getStride(), EltTy, Builder); Value *Vector = Builder.CreateAlignedLoad( VecTy, GEP, getAlignForIndex(I, Stride, EltTy, MAlign), IsVolatile, "col.load"); @@ -1071,9 +1071,11 @@ public: auto VType = cast<VectorType>(Ty); Value *EltPtr = createElementPtr(Ptr, VType->getElementType(), Builder); for (auto Vec : enumerate(StoreVal.vectors())) { - Value *GEP = computeVectorAddr(EltPtr, Builder.getInt64(Vec.index()), - Stride, StoreVal.getStride(), - VType->getElementType(), Builder); + Value *GEP = computeVectorAddr( + EltPtr, + Builder.getIntN(Stride->getType()->getScalarSizeInBits(), + Vec.index()), + Stride, StoreVal.getStride(), VType->getElementType(), Builder); Builder.CreateAlignedStore(Vec.value(), GEP, getAlignForIndex(Vec.index(), Stride, VType->getElementType(), @@ -2261,6 +2263,16 @@ PreservedAnalyses LowerMatrixIntrinsicsPass::run(Function &F, return PreservedAnalyses::all(); } +void LowerMatrixIntrinsicsPass::printPipeline( + raw_ostream &OS, function_ref<StringRef(StringRef)> MapClassName2PassName) { + static_cast<PassInfoMixin<LowerMatrixIntrinsicsPass> *>(this)->printPipeline( + OS, MapClassName2PassName); + OS << "<"; + if (Minimal) + OS << "minimal"; + OS << ">"; +} + namespace { class LowerMatrixIntrinsicsLegacyPass : public FunctionPass { |
