Diffstat (limited to 'llvm/lib/Target/ARM/MVEGatherScatterLowering.cpp')
-rw-r--r-- | llvm/lib/Target/ARM/MVEGatherScatterLowering.cpp | 911 |
1 file changed, 823 insertions, 88 deletions
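This commit extends the MVE gather/scatter lowering pass: it lowers masked scatters as well as gathers, handles extending/truncating and incrementing/writeback forms, and hoists loop-invariant offset arithmetic out of loops. As a reading aid, here is a scalar model of the base-plus-scaled-offsets addressing the pass targets (illustrative code, not part of the commit):

```cpp
#include <cstddef>
#include <cstdint>

// Scalar model of the addressing mode this pass forms: a scalar base
// pointer plus a vector of 32-bit offsets, each scaled by the element
// size (shift by 2 for 32-bit elements).
void gatherBasePlusOffsets(const int32_t *Base, const uint32_t Offsets[4],
                           int32_t Out[4]) {
  for (size_t I = 0; I < 4; ++I)
    Out[I] = *reinterpret_cast<const int32_t *>(
        reinterpret_cast<const char *>(Base) +
        (static_cast<size_t>(Offsets[I]) << 2));
}
```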
diff --git a/llvm/lib/Target/ARM/MVEGatherScatterLowering.cpp b/llvm/lib/Target/ARM/MVEGatherScatterLowering.cpp
index 9f64af02e698..4d7ad6cd60cb 100644
--- a/llvm/lib/Target/ARM/MVEGatherScatterLowering.cpp
+++ b/llvm/lib/Target/ARM/MVEGatherScatterLowering.cpp
@@ -15,6 +15,7 @@
 #include "ARM.h"
 #include "ARMBaseInstrInfo.h"
 #include "ARMSubtarget.h"
+#include "llvm/Analysis/LoopInfo.h"
 #include "llvm/Analysis/TargetTransformInfo.h"
 #include "llvm/CodeGen/TargetLowering.h"
 #include "llvm/CodeGen/TargetPassConfig.h"
@@ -37,6 +38,7 @@
 #include "llvm/IR/Value.h"
 #include "llvm/Pass.h"
 #include "llvm/Support/Casting.h"
+#include "llvm/Transforms/Utils/Local.h"
 #include <algorithm>
 #include <cassert>
@@ -67,27 +69,77 @@ public:
   void getAnalysisUsage(AnalysisUsage &AU) const override {
     AU.setPreservesCFG();
     AU.addRequired<TargetPassConfig>();
+    AU.addRequired<LoopInfoWrapperPass>();
     FunctionPass::getAnalysisUsage(AU);
   }

 private:
+  LoopInfo *LI = nullptr;
+
   // Check this is a valid gather with correct alignment
   bool isLegalTypeAndAlignment(unsigned NumElements, unsigned ElemSize,
-                               unsigned Alignment);
+                               Align Alignment);
   // Check whether Ptr is hidden behind a bitcast and look through it
   void lookThroughBitcast(Value *&Ptr);
   // Check for a getelementptr and deduce base and offsets from it, on success
   // returning the base directly and the offsets indirectly using the Offsets
   // argument
-  Value *checkGEP(Value *&Offsets, Type *Ty, Value *Ptr, IRBuilder<> Builder);
+  Value *checkGEP(Value *&Offsets, Type *Ty, GetElementPtrInst *GEP,
+                  IRBuilder<> &Builder);
+  // Compute the scale of this gather/scatter instruction
+  int computeScale(unsigned GEPElemSize, unsigned MemoryElemSize);
+  // If the value is a constant, or derived from constants via additions
+  // and multiplications, return its numeric value
+  Optional<int64_t> getIfConst(const Value *V);
+  // If Inst is an add instruction, check whether one summand is a
+  // constant. If so, scale this constant and return it together with
+  // the other summand.
+  std::pair<Value *, int64_t> getVarAndConst(Value *Inst, int TypeScale);

-  bool lowerGather(IntrinsicInst *I);
+  Value *lowerGather(IntrinsicInst *I);
   // Create a gather from a base + vector of offsets
   Value *tryCreateMaskedGatherOffset(IntrinsicInst *I, Value *Ptr,
-                                     IRBuilder<> Builder);
+                                     Instruction *&Root, IRBuilder<> &Builder);
   // Create a gather from a vector of pointers
   Value *tryCreateMaskedGatherBase(IntrinsicInst *I, Value *Ptr,
-                                   IRBuilder<> Builder);
+                                   IRBuilder<> &Builder, int64_t Increment = 0);
+  // Create an incrementing gather from a vector of pointers
+  Value *tryCreateMaskedGatherBaseWB(IntrinsicInst *I, Value *Ptr,
+                                     IRBuilder<> &Builder,
+                                     int64_t Increment = 0);
+
+  Value *lowerScatter(IntrinsicInst *I);
+  // Create a scatter to a base + vector of offsets
+  Value *tryCreateMaskedScatterOffset(IntrinsicInst *I, Value *Offsets,
+                                      IRBuilder<> &Builder);
+  // Create a scatter to a vector of pointers
+  Value *tryCreateMaskedScatterBase(IntrinsicInst *I, Value *Ptr,
+                                    IRBuilder<> &Builder,
+                                    int64_t Increment = 0);
+  // Create an incrementing scatter from a vector of pointers
+  Value *tryCreateMaskedScatterBaseWB(IntrinsicInst *I, Value *Ptr,
+                                      IRBuilder<> &Builder,
+                                      int64_t Increment = 0);
+
+  // QI gathers and scatters can increment their offsets on their own if
+  // the increment is a constant value (digit)
+  Value *tryCreateIncrementingGatScat(IntrinsicInst *I, Value *BasePtr,
+                                      Value *Ptr, GetElementPtrInst *GEP,
+                                      IRBuilder<> &Builder);
+  // QI gathers/scatters can increment their offsets on their own if the
+  // increment is a constant value (digit) - this creates a writeback QI
+  // gather/scatter
+  Value *tryCreateIncrementingWBGatScat(IntrinsicInst *I, Value *BasePtr,
+                                        Value *Ptr, unsigned TypeScale,
+                                        IRBuilder<> &Builder);
+  // Check whether these offsets could be moved out of the loop they're in
+  bool optimiseOffsets(Value *Offsets, BasicBlock *BB, LoopInfo *LI);
+  // Pushes the given add out of the loop
+  void pushOutAdd(PHINode *&Phi, Value *OffsSecondOperand, unsigned StartIndex);
+  // Pushes the given mul out of the loop
+  void pushOutMul(PHINode *&Phi, Value *IncrementPerRound,
+                  Value *OffsSecondOperand, unsigned LoopIncrement,
+                  IRBuilder<> &Builder);
 };

 } // end anonymous namespace
@@ -103,102 +155,177 @@ Pass *llvm::createMVEGatherScatterLoweringPass() {

 bool MVEGatherScatterLowering::isLegalTypeAndAlignment(unsigned NumElements,
                                                        unsigned ElemSize,
-                                                       unsigned Alignment) {
-  // Do only allow non-extending gathers for now
-  if (((NumElements == 4 && ElemSize == 32) ||
-       (NumElements == 8 && ElemSize == 16) ||
+                                                       Align Alignment) {
+  if (((NumElements == 4 &&
+        (ElemSize == 32 || ElemSize == 16 || ElemSize == 8)) ||
+       (NumElements == 8 && (ElemSize == 16 || ElemSize == 8)) ||
       (NumElements == 16 && ElemSize == 8)) &&
-      ElemSize / 8 <= Alignment)
     return true;
-  LLVM_DEBUG(dbgs() << "masked gathers: instruction does not have valid "
-                    << "alignment or vector type \n");
+      Alignment >= ElemSize / 8)
+  LLVM_DEBUG(dbgs() << "masked gathers/scatters: instruction does not have "
+                    << "valid alignment or vector type \n");
   return false;
 }

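The reworked legality check above now also admits the narrower memory types used by extending gathers and truncating scatters. A standalone mirror of the predicate, useful for checking combinations by hand (a sketch; the pass's own method takes an llvm::Align rather than a byte count):

```cpp
// Mirror of the legality rule: the lane-count/element-size combination must
// map to an MVE gather/scatter, and the alignment must cover one element.
bool legalTypeAndAlignment(unsigned NumElements, unsigned ElemSize,
                           unsigned AlignmentBytes) {
  bool LegalShape = (NumElements == 4 &&
                     (ElemSize == 32 || ElemSize == 16 || ElemSize == 8)) ||
                    (NumElements == 8 && (ElemSize == 16 || ElemSize == 8)) ||
                    (NumElements == 16 && ElemSize == 8);
  return LegalShape && AlignmentBytes >= ElemSize / 8;
}
// e.g. legalTypeAndAlignment(8, 16, 2) -> true  (v8i16, naturally aligned)
//      legalTypeAndAlignment(8, 32, 4) -> false (v8i32 exceeds 128 bits)
```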
-Value *MVEGatherScatterLowering::checkGEP(Value *&Offsets, Type *Ty, Value *Ptr,
-                                          IRBuilder<> Builder) {
-  GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Ptr);
+Value *MVEGatherScatterLowering::checkGEP(Value *&Offsets, Type *Ty,
+                                          GetElementPtrInst *GEP,
+                                          IRBuilder<> &Builder) {
   if (!GEP) {
-    LLVM_DEBUG(dbgs() << "masked gathers: no getelementpointer found\n");
+    LLVM_DEBUG(
+        dbgs() << "masked gathers/scatters: no getelementpointer found\n");
     return nullptr;
   }
-  LLVM_DEBUG(dbgs() << "masked gathers: getelementpointer found. Loading"
-                    << " from base + vector of offsets\n");
+  LLVM_DEBUG(dbgs() << "masked gathers/scatters: getelementpointer found."
+                    << " Looking at intrinsic for base + vector of offsets\n");
   Value *GEPPtr = GEP->getPointerOperand();
   if (GEPPtr->getType()->isVectorTy()) {
-    LLVM_DEBUG(dbgs() << "masked gathers: gather from a vector of pointers"
-                      << " hidden behind a getelementptr currently not"
-                      << " supported. Expanding.\n");
     return nullptr;
   }
   if (GEP->getNumOperands() != 2) {
-    LLVM_DEBUG(dbgs() << "masked gathers: getelementptr with too many"
+    LLVM_DEBUG(dbgs() << "masked gathers/scatters: getelementptr with too many"
                       << " operands. Expanding.\n");
     return nullptr;
   }
   Offsets = GEP->getOperand(1);
-  // SExt offsets inside masked gathers are not permitted by the architecture;
-  // we therefore can't fold them
+  // Paranoid check whether the number of parallel lanes is the same
+  assert(cast<FixedVectorType>(Ty)->getNumElements() ==
+         cast<FixedVectorType>(Offsets->getType())->getNumElements());
+  // Only <N x i32> offsets can be integrated into an arm gather, any smaller
+  // type would have to be sign extended by the gep - and arm gathers can only
+  // zero extend. Additionally, the offsets do have to originate from a zext of
+  // a vector with element types smaller or equal the type of the gather we're
+  // looking at
+  if (Offsets->getType()->getScalarSizeInBits() != 32)
+    return nullptr;
   if (ZExtInst *ZextOffs = dyn_cast<ZExtInst>(Offsets))
     Offsets = ZextOffs->getOperand(0);
-  Type *OffsType = VectorType::getInteger(cast<VectorType>(Ty));
-  // If the offset we found does not have the type the intrinsic expects,
-  // i.e., the same type as the gather itself, we need to convert it (only i
-  // types) or fall back to expanding the gather
-  if (OffsType != Offsets->getType()) {
-    if (OffsType->getScalarSizeInBits() >
-        Offsets->getType()->getScalarSizeInBits()) {
-      LLVM_DEBUG(dbgs() << "masked gathers: extending offsets\n");
-      Offsets = Builder.CreateZExt(Offsets, OffsType, "");
-    } else {
-      LLVM_DEBUG(dbgs() << "masked gathers: no correct offset type. Can't"
-                        << " create masked gather\n");
+  else if (!(cast<FixedVectorType>(Offsets->getType())->getNumElements() == 4 &&
+             Offsets->getType()->getScalarSizeInBits() == 32))
+    return nullptr;
+
+  if (Ty != Offsets->getType()) {
+    if ((Ty->getScalarSizeInBits() <
+         Offsets->getType()->getScalarSizeInBits())) {
+      LLVM_DEBUG(dbgs() << "masked gathers/scatters: no correct offset type."
+                        << " Can't create intrinsic.\n");
       return nullptr;
+    } else {
+      Offsets = Builder.CreateZExt(
+          Offsets, VectorType::getInteger(cast<VectorType>(Ty)));
     }
   }
   // If none of the checks failed, return the gep's base pointer
+  LLVM_DEBUG(dbgs() << "masked gathers/scatters: found correct offsets\n");
   return GEPPtr;
 }

 void MVEGatherScatterLowering::lookThroughBitcast(Value *&Ptr) {
   // Look through bitcast instruction if #elements is the same
   if (auto *BitCast = dyn_cast<BitCastInst>(Ptr)) {
-    Type *BCTy = BitCast->getType();
-    Type *BCSrcTy = BitCast->getOperand(0)->getType();
-    if (BCTy->getVectorNumElements() == BCSrcTy->getVectorNumElements()) {
-      LLVM_DEBUG(dbgs() << "masked gathers: looking through bitcast\n");
+    auto *BCTy = cast<FixedVectorType>(BitCast->getType());
+    auto *BCSrcTy = cast<FixedVectorType>(BitCast->getOperand(0)->getType());
+    if (BCTy->getNumElements() == BCSrcTy->getNumElements()) {
+      LLVM_DEBUG(
+          dbgs() << "masked gathers/scatters: looking through bitcast\n");
       Ptr = BitCast->getOperand(0);
     }
   }
 }

-bool MVEGatherScatterLowering::lowerGather(IntrinsicInst *I) {
+int MVEGatherScatterLowering::computeScale(unsigned GEPElemSize,
+                                           unsigned MemoryElemSize) {
+  // This can be a 32bit load/store scaled by 4, a 16bit load/store scaled by
+  // 2, or a 8bit, 16bit or 32bit load/store scaled by 1
+  if (GEPElemSize == 32 && MemoryElemSize == 32)
+    return 2;
+  else if (GEPElemSize == 16 && MemoryElemSize == 16)
+    return 1;
+  else if (GEPElemSize == 8)
+    return 0;
+  LLVM_DEBUG(dbgs() << "masked gathers/scatters: incorrect scale. Can't "
+                    << "create intrinsic\n");
+  return -1;
+}
+
+Optional<int64_t> MVEGatherScatterLowering::getIfConst(const Value *V) {
+  const Constant *C = dyn_cast<Constant>(V);
+  if (C != nullptr)
+    return Optional<int64_t>{C->getUniqueInteger().getSExtValue()};
+  if (!isa<Instruction>(V))
+    return Optional<int64_t>{};
+
+  const Instruction *I = cast<Instruction>(V);
+  if (I->getOpcode() == Instruction::Add ||
+      I->getOpcode() == Instruction::Mul) {
+    Optional<int64_t> Op0 = getIfConst(I->getOperand(0));
+    Optional<int64_t> Op1 = getIfConst(I->getOperand(1));
+    if (!Op0 || !Op1)
+      return Optional<int64_t>{};
+    if (I->getOpcode() == Instruction::Add)
+      return Optional<int64_t>{Op0.getValue() + Op1.getValue()};
+    if (I->getOpcode() == Instruction::Mul)
+      return Optional<int64_t>{Op0.getValue() * Op1.getValue()};
+  }
+  return Optional<int64_t>{};
+}
+
+std::pair<Value *, int64_t>
+MVEGatherScatterLowering::getVarAndConst(Value *Inst, int TypeScale) {
+  std::pair<Value *, int64_t> ReturnFalse =
+      std::pair<Value *, int64_t>(nullptr, 0);
+  // At this point, the instruction we're looking at must be an add or we
+  // bail out
+  Instruction *Add = dyn_cast<Instruction>(Inst);
+  if (Add == nullptr || Add->getOpcode() != Instruction::Add)
+    return ReturnFalse;
+
+  Value *Summand;
+  Optional<int64_t> Const;
+  // Find out which operand the value that is increased is
+  if ((Const = getIfConst(Add->getOperand(0))))
+    Summand = Add->getOperand(1);
+  else if ((Const = getIfConst(Add->getOperand(1))))
+    Summand = Add->getOperand(0);
+  else
+    return ReturnFalse;
+
+  // Check that the constant is small enough for an incrementing gather
+  int64_t Immediate = Const.getValue() << TypeScale;
+  if (Immediate > 512 || Immediate < -512 || Immediate % 4 != 0)
+    return ReturnFalse;
+
+  return std::pair<Value *, int64_t>(Summand, Immediate);
+}
+
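computeScale and getVarAndConst cooperate: the scale turns an element step into a byte immediate, and only immediates in [-512, 512] that are multiples of 4 survive. A self-contained sketch of that arithmetic (the helper name is illustrative, not from the patch):

```cpp
#include <cstdint>
#include <optional>

// An offset step of `Step` elements becomes a byte immediate of
// Step << TypeScale; the incrementing forms only encode immediates in
// [-512, 512] that are multiples of 4 (constants from the patch).
std::optional<int64_t> encodableIncrement(int64_t Step, int TypeScale) {
  int64_t Immediate = Step << TypeScale;
  if (Immediate > 512 || Immediate < -512 || Immediate % 4 != 0)
    return std::nullopt;
  return Immediate;
}
// e.g. a v4i32 gather (TypeScale 2) stepping by 4 elements per iteration
// gives encodableIncrement(4, 2) == 16 bytes, which is encodable;
// a step of 200 elements gives 800 bytes, which is rejected.
```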
+Value *MVEGatherScatterLowering::lowerGather(IntrinsicInst *I) {
   using namespace PatternMatch;
   LLVM_DEBUG(dbgs() << "masked gathers: checking transform preconditions\n");

   // @llvm.masked.gather.*(Ptrs, alignment, Mask, Src0)
   // Attempt to turn the masked gather in I into a MVE intrinsic
   // Potentially optimising the addressing modes as we do so.
-  Type *Ty = I->getType();
+  auto *Ty = cast<FixedVectorType>(I->getType());
   Value *Ptr = I->getArgOperand(0);
-  unsigned Alignment = cast<ConstantInt>(I->getArgOperand(1))->getZExtValue();
+  Align Alignment = cast<ConstantInt>(I->getArgOperand(1))->getAlignValue();
   Value *Mask = I->getArgOperand(2);
   Value *PassThru = I->getArgOperand(3);

-  if (!isLegalTypeAndAlignment(Ty->getVectorNumElements(),
-                               Ty->getScalarSizeInBits(), Alignment))
-    return false;
+  if (!isLegalTypeAndAlignment(Ty->getNumElements(), Ty->getScalarSizeInBits(),
+                               Alignment))
+    return nullptr;
   lookThroughBitcast(Ptr);
   assert(Ptr->getType()->isVectorTy() && "Unexpected pointer type");

   IRBuilder<> Builder(I->getContext());
   Builder.SetInsertPoint(I);
   Builder.SetCurrentDebugLocation(I->getDebugLoc());
-  Value *Load = tryCreateMaskedGatherOffset(I, Ptr, Builder);
+
+  Instruction *Root = I;
+  Value *Load = tryCreateMaskedGatherOffset(I, Ptr, Root, Builder);
   if (!Load)
     Load = tryCreateMaskedGatherBase(I, Ptr, Builder);
   if (!Load)
-    return false;
+    return nullptr;

   if (!isa<UndefValue>(PassThru) && !match(PassThru, m_Zero())) {
     LLVM_DEBUG(dbgs() << "masked gathers: found non-trivial passthru - "
@@ -206,72 +333,649 @@ bool MVEGatherScatterLowering::lowerGather(IntrinsicInst *I) {
     Load = Builder.CreateSelect(Mask, Load, PassThru);
   }

+  Root->replaceAllUsesWith(Load);
+  Root->eraseFromParent();
+  if (Root != I)
+    // If this was an extending gather, we need to get rid of the sext/zext
+    // as well as of the gather itself
+    I->eraseFromParent();
+
   LLVM_DEBUG(dbgs() << "masked gathers: successfully built masked gather\n");
-  I->replaceAllUsesWith(Load);
-  I->eraseFromParent();
-  return true;
+  return Load;
 }

-Value *MVEGatherScatterLowering::tryCreateMaskedGatherBase(
-    IntrinsicInst *I, Value *Ptr, IRBuilder<> Builder) {
+Value *MVEGatherScatterLowering::tryCreateMaskedGatherBase(IntrinsicInst *I,
+                                                           Value *Ptr,
+                                                           IRBuilder<> &Builder,
+                                                           int64_t Increment) {
   using namespace PatternMatch;
+  auto *Ty = cast<FixedVectorType>(I->getType());
   LLVM_DEBUG(dbgs() << "masked gathers: loading from vector of pointers\n");
-  Type *Ty = I->getType();
-  if (Ty->getVectorNumElements() != 4)
+  if (Ty->getNumElements() != 4 || Ty->getScalarSizeInBits() != 32)
     // Can't build an intrinsic for this
     return nullptr;
   Value *Mask = I->getArgOperand(2);
   if (match(Mask, m_One()))
     return Builder.CreateIntrinsic(Intrinsic::arm_mve_vldr_gather_base,
                                    {Ty, Ptr->getType()},
-                                   {Ptr, Builder.getInt32(0)});
+                                   {Ptr, Builder.getInt32(Increment)});
   else
     return Builder.CreateIntrinsic(
         Intrinsic::arm_mve_vldr_gather_base_predicated,
         {Ty, Ptr->getType(), Mask->getType()},
-        {Ptr, Builder.getInt32(0), Mask});
+        {Ptr, Builder.getInt32(Increment), Mask});
+}
+
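All of the emitters in this patch pick between the plain and _predicated intrinsic depending on whether the mask is all-true, via PatternMatch's m_One(), which matches a constant 1 and, for vector masks, an all-ones splat. Roughly (a sketch of the pattern, not the pass's API):

```cpp
#include "llvm/IR/PatternMatch.h"
#include "llvm/IR/Value.h"

// Sketch: the predicated MVE intrinsic is only needed when the mask is
// not a compile-time all-true vector.
static bool needsPredicatedForm(llvm::Value *Mask) {
  using namespace llvm::PatternMatch;
  return !match(Mask, m_One());
}
```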
+Value *MVEGatherScatterLowering::tryCreateMaskedGatherBaseWB(
+    IntrinsicInst *I, Value *Ptr, IRBuilder<> &Builder, int64_t Increment) {
+  using namespace PatternMatch;
+  auto *Ty = cast<FixedVectorType>(I->getType());
+  LLVM_DEBUG(
+      dbgs()
+      << "masked gathers: loading from vector of pointers with writeback\n");
+  if (Ty->getNumElements() != 4 || Ty->getScalarSizeInBits() != 32)
+    // Can't build an intrinsic for this
+    return nullptr;
+  Value *Mask = I->getArgOperand(2);
+  if (match(Mask, m_One()))
+    return Builder.CreateIntrinsic(Intrinsic::arm_mve_vldr_gather_base_wb,
+                                   {Ty, Ptr->getType()},
+                                   {Ptr, Builder.getInt32(Increment)});
+  else
+    return Builder.CreateIntrinsic(
+        Intrinsic::arm_mve_vldr_gather_base_wb_predicated,
+        {Ty, Ptr->getType(), Mask->getType()},
+        {Ptr, Builder.getInt32(Increment), Mask});
 }

 Value *MVEGatherScatterLowering::tryCreateMaskedGatherOffset(
-    IntrinsicInst *I, Value *Ptr, IRBuilder<> Builder) {
+    IntrinsicInst *I, Value *Ptr, Instruction *&Root, IRBuilder<> &Builder) {
   using namespace PatternMatch;
-  Type *Ty = I->getType();
+
+  Type *OriginalTy = I->getType();
+  Type *ResultTy = OriginalTy;
+
+  unsigned Unsigned = 1;
+  // The size of the gather was already checked in isLegalTypeAndAlignment;
+  // if it was not a full vector width an appropriate extend should follow.
+  auto *Extend = Root;
+  if (OriginalTy->getPrimitiveSizeInBits() < 128) {
+    // Only transform gathers with exactly one use
+    if (!I->hasOneUse())
+      return nullptr;
+
+    // The correct root to replace is not the CallInst itself, but the
+    // instruction which extends it
+    Extend = cast<Instruction>(*I->users().begin());
+    if (isa<SExtInst>(Extend)) {
+      Unsigned = 0;
+    } else if (!isa<ZExtInst>(Extend)) {
+      LLVM_DEBUG(dbgs() << "masked gathers: extend needed but not provided. "
+                        << "Expanding\n");
+      return nullptr;
+    }
+    LLVM_DEBUG(dbgs() << "masked gathers: found an extending gather\n");
+    ResultTy = Extend->getType();
+    // The final size of the gather must be a full vector width
+    if (ResultTy->getPrimitiveSizeInBits() != 128) {
+      LLVM_DEBUG(dbgs() << "masked gathers: extending from the wrong type. "
+                        << "Expanding\n");
+      return nullptr;
+    }
+  }
+
+  GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Ptr);
   Value *Offsets;
-  Value *BasePtr = checkGEP(Offsets, Ty, Ptr, Builder);
+  Value *BasePtr = checkGEP(Offsets, ResultTy, GEP, Builder);
   if (!BasePtr)
     return nullptr;
+  // Check whether the offset is a constant increment that could be merged into
+  // a QI gather
+  Value *Load = tryCreateIncrementingGatScat(I, BasePtr, Offsets, GEP, Builder);
+  if (Load)
+    return Load;

-  unsigned Scale;
-  int GEPElemSize =
-      BasePtr->getType()->getPointerElementType()->getPrimitiveSizeInBits();
-  int ResultElemSize = Ty->getScalarSizeInBits();
-  // This can be a 32bit load scaled by 4, a 16bit load scaled by 2, or a
-  // 8bit, 16bit or 32bit load scaled by 1
-  if (GEPElemSize == 32 && ResultElemSize == 32) {
-    Scale = 2;
-  } else if (GEPElemSize == 16 && ResultElemSize == 16) {
-    Scale = 1;
-  } else if (GEPElemSize == 8) {
-    Scale = 0;
-  } else {
-    LLVM_DEBUG(dbgs() << "masked gathers: incorrect scale for load. Can't"
-                      << " create masked gather\n");
+  int Scale = computeScale(
+      BasePtr->getType()->getPointerElementType()->getPrimitiveSizeInBits(),
+      OriginalTy->getScalarSizeInBits());
+  if (Scale == -1)
     return nullptr;
-  }
+  Root = Extend;

   Value *Mask = I->getArgOperand(2);
   if (!match(Mask, m_One()))
     return Builder.CreateIntrinsic(
         Intrinsic::arm_mve_vldr_gather_offset_predicated,
-        {Ty, BasePtr->getType(), Offsets->getType(), Mask->getType()},
-        {BasePtr, Offsets, Builder.getInt32(Ty->getScalarSizeInBits()),
-         Builder.getInt32(Scale), Builder.getInt32(1), Mask});
+        {ResultTy, BasePtr->getType(), Offsets->getType(), Mask->getType()},
+        {BasePtr, Offsets, Builder.getInt32(OriginalTy->getScalarSizeInBits()),
+         Builder.getInt32(Scale), Builder.getInt32(Unsigned), Mask});
   else
     return Builder.CreateIntrinsic(
         Intrinsic::arm_mve_vldr_gather_offset,
-        {Ty, BasePtr->getType(), Offsets->getType()},
-        {BasePtr, Offsets, Builder.getInt32(Ty->getScalarSizeInBits()),
-         Builder.getInt32(Scale), Builder.getInt32(1)});
+        {ResultTy, BasePtr->getType(), Offsets->getType()},
+        {BasePtr, Offsets, Builder.getInt32(OriginalTy->getScalarSizeInBits()),
+         Builder.getInt32(Scale), Builder.getInt32(Unsigned)});
+}
+
+Value *MVEGatherScatterLowering::lowerScatter(IntrinsicInst *I) {
+  using namespace PatternMatch;
+  LLVM_DEBUG(dbgs() << "masked scatters: checking transform preconditions\n");
+
+  // @llvm.masked.scatter.*(data, ptrs, alignment, mask)
+  // Attempt to turn the masked scatter in I into a MVE intrinsic
+  // Potentially optimising the addressing modes as we do so.
+  Value *Input = I->getArgOperand(0);
+  Value *Ptr = I->getArgOperand(1);
+  Align Alignment = cast<ConstantInt>(I->getArgOperand(2))->getAlignValue();
+  auto *Ty = cast<FixedVectorType>(Input->getType());
+
+  if (!isLegalTypeAndAlignment(Ty->getNumElements(), Ty->getScalarSizeInBits(),
+                               Alignment))
+    return nullptr;
+
+  lookThroughBitcast(Ptr);
+  assert(Ptr->getType()->isVectorTy() && "Unexpected pointer type");
+
+  IRBuilder<> Builder(I->getContext());
+  Builder.SetInsertPoint(I);
+  Builder.SetCurrentDebugLocation(I->getDebugLoc());
+
+  Value *Store = tryCreateMaskedScatterOffset(I, Ptr, Builder);
+  if (!Store)
+    Store = tryCreateMaskedScatterBase(I, Ptr, Builder);
+  if (!Store)
+    return nullptr;
+
+  LLVM_DEBUG(dbgs() << "masked scatters: successfully built masked scatter\n");
+  I->eraseFromParent();
+  return Store;
+}
+
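tryCreateMaskedGatherOffset above folds a single sext/zext user into the gather itself, so a v4i16 load that is immediately extended to v4i32 becomes one extending gather. A scalar model of the signed case (illustrative code, not from the commit):

```cpp
#include <cstddef>
#include <cstdint>

// Scalar model of an extending gather: i16 memory, i32 result lanes, with
// the extend folded into the load (a sext user selects the signed form,
// Unsigned = 0; a zext user the unsigned one).
void extendingGatherS16ToS32(const int16_t *Base, const uint32_t Offsets[4],
                             int32_t Out[4]) {
  for (size_t I = 0; I < 4; ++I)
    Out[I] = static_cast<int32_t>(Base[Offsets[I]]); // one MVE instruction
}
```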
+Value *MVEGatherScatterLowering::tryCreateMaskedScatterBase(
+    IntrinsicInst *I, Value *Ptr, IRBuilder<> &Builder, int64_t Increment) {
+  using namespace PatternMatch;
+  Value *Input = I->getArgOperand(0);
+  auto *Ty = cast<FixedVectorType>(Input->getType());
+  // Only QR variants allow truncating
+  if (!(Ty->getNumElements() == 4 && Ty->getScalarSizeInBits() == 32)) {
+    // Can't build an intrinsic for this
+    return nullptr;
+  }
+  Value *Mask = I->getArgOperand(3);
+  //  int_arm_mve_vstr_scatter_base(_predicated) addr, offset, data(, mask)
+  LLVM_DEBUG(dbgs() << "masked scatters: storing to a vector of pointers\n");
+  if (match(Mask, m_One()))
+    return Builder.CreateIntrinsic(Intrinsic::arm_mve_vstr_scatter_base,
+                                   {Ptr->getType(), Input->getType()},
+                                   {Ptr, Builder.getInt32(Increment), Input});
+  else
+    return Builder.CreateIntrinsic(
+        Intrinsic::arm_mve_vstr_scatter_base_predicated,
+        {Ptr->getType(), Input->getType(), Mask->getType()},
+        {Ptr, Builder.getInt32(Increment), Input, Mask});
+}
+
+Value *MVEGatherScatterLowering::tryCreateMaskedScatterBaseWB(
+    IntrinsicInst *I, Value *Ptr, IRBuilder<> &Builder, int64_t Increment) {
+  using namespace PatternMatch;
+  Value *Input = I->getArgOperand(0);
+  auto *Ty = cast<FixedVectorType>(Input->getType());
+  LLVM_DEBUG(
+      dbgs()
+      << "masked scatters: storing to a vector of pointers with writeback\n");
+  if (Ty->getNumElements() != 4 || Ty->getScalarSizeInBits() != 32)
+    // Can't build an intrinsic for this
+    return nullptr;
+  Value *Mask = I->getArgOperand(3);
+  if (match(Mask, m_One()))
+    return Builder.CreateIntrinsic(Intrinsic::arm_mve_vstr_scatter_base_wb,
+                                   {Ptr->getType(), Input->getType()},
+                                   {Ptr, Builder.getInt32(Increment), Input});
+  else
+    return Builder.CreateIntrinsic(
+        Intrinsic::arm_mve_vstr_scatter_base_wb_predicated,
+        {Ptr->getType(), Input->getType(), Mask->getType()},
+        {Ptr, Builder.getInt32(Increment), Input, Mask});
+}
+
+Value *MVEGatherScatterLowering::tryCreateMaskedScatterOffset(
+    IntrinsicInst *I, Value *Ptr, IRBuilder<> &Builder) {
+  using namespace PatternMatch;
+  Value *Input = I->getArgOperand(0);
+  Value *Mask = I->getArgOperand(3);
+  Type *InputTy = Input->getType();
+  Type *MemoryTy = InputTy;
+  LLVM_DEBUG(dbgs() << "masked scatters: getelementpointer found. Storing"
+                    << " to base + vector of offsets\n");
+  // If the input has been truncated, try to integrate that trunc into the
+  // scatter instruction (we don't care about alignment here)
+  if (TruncInst *Trunc = dyn_cast<TruncInst>(Input)) {
+    Value *PreTrunc = Trunc->getOperand(0);
+    Type *PreTruncTy = PreTrunc->getType();
+    if (PreTruncTy->getPrimitiveSizeInBits() == 128) {
+      Input = PreTrunc;
+      InputTy = PreTruncTy;
+    }
+  }
+  if (InputTy->getPrimitiveSizeInBits() != 128) {
+    LLVM_DEBUG(
+        dbgs() << "masked scatters: cannot create scatters for non-standard"
+               << " input types. Expanding.\n");
+    return nullptr;
+  }
+
+  GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Ptr);
+  Value *Offsets;
+  Value *BasePtr = checkGEP(Offsets, InputTy, GEP, Builder);
+  if (!BasePtr)
+    return nullptr;
+  // Check whether the offset is a constant increment that could be merged into
+  // a QI gather
+  Value *Store =
+      tryCreateIncrementingGatScat(I, BasePtr, Offsets, GEP, Builder);
+  if (Store)
+    return Store;
+  int Scale = computeScale(
+      BasePtr->getType()->getPointerElementType()->getPrimitiveSizeInBits(),
+      MemoryTy->getScalarSizeInBits());
+  if (Scale == -1)
+    return nullptr;
+
+  if (!match(Mask, m_One()))
+    return Builder.CreateIntrinsic(
+        Intrinsic::arm_mve_vstr_scatter_offset_predicated,
+        {BasePtr->getType(), Offsets->getType(), Input->getType(),
+         Mask->getType()},
+        {BasePtr, Offsets, Input,
+         Builder.getInt32(MemoryTy->getScalarSizeInBits()),
+         Builder.getInt32(Scale), Mask});
+  else
+    return Builder.CreateIntrinsic(
+        Intrinsic::arm_mve_vstr_scatter_offset,
+        {BasePtr->getType(), Offsets->getType(), Input->getType()},
+        {BasePtr, Offsets, Input,
+         Builder.getInt32(MemoryTy->getScalarSizeInBits()),
+         Builder.getInt32(Scale)});
+}
+
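Symmetrically, tryCreateMaskedScatterOffset looks through a trunc feeding the scatter, so the narrowing happens inside the store. Scalar model (illustrative code, not from the commit):

```cpp
#include <cstddef>
#include <cstdint>

// Scalar model of a truncating scatter: i32 lanes stored to i16 memory,
// with the trunc folded into the store.
void truncatingScatterS32ToS16(int16_t *Base, const uint32_t Offsets[4],
                               const int32_t In[4]) {
  for (size_t I = 0; I < 4; ++I)
    Base[Offsets[I]] = static_cast<int16_t>(In[I]); // truncating store
}
```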
"masked gathers/scatters: trying to build incrementing " + "wb gather/scatter\n"); + + // The gep was in charge of making sure the offsets are scaled correctly + // - calculate that factor so it can be applied by hand + DataLayout DT = I->getParent()->getParent()->getParent()->getDataLayout(); + int TypeScale = + computeScale(DT.getTypeSizeInBits(GEP->getOperand(0)->getType()), + DT.getTypeSizeInBits(GEP->getType()) / + cast<FixedVectorType>(GEP->getType())->getNumElements()); + if (TypeScale == -1) + return nullptr; + + if (GEP->hasOneUse()) { + // Only in this case do we want to build a wb gather, because the wb will + // change the phi which does affect other users of the gep (which will still + // be using the phi in the old way) + Value *Load = + tryCreateIncrementingWBGatScat(I, BasePtr, Offsets, TypeScale, Builder); + if (Load != nullptr) + return Load; + } + LLVM_DEBUG(dbgs() << "masked gathers/scatters: trying to build incrementing " + "non-wb gather/scatter\n"); + + std::pair<Value *, int64_t> Add = getVarAndConst(Offsets, TypeScale); + if (Add.first == nullptr) + return nullptr; + Value *OffsetsIncoming = Add.first; + int64_t Immediate = Add.second; + + // Make sure the offsets are scaled correctly + Instruction *ScaledOffsets = BinaryOperator::Create( + Instruction::Shl, OffsetsIncoming, + Builder.CreateVectorSplat(Ty->getNumElements(), Builder.getInt32(TypeScale)), + "ScaledIndex", I); + // Add the base to the offsets + OffsetsIncoming = BinaryOperator::Create( + Instruction::Add, ScaledOffsets, + Builder.CreateVectorSplat( + Ty->getNumElements(), + Builder.CreatePtrToInt( + BasePtr, + cast<VectorType>(ScaledOffsets->getType())->getElementType())), + "StartIndex", I); + + if (I->getIntrinsicID() == Intrinsic::masked_gather) + return cast<IntrinsicInst>( + tryCreateMaskedGatherBase(I, OffsetsIncoming, Builder, Immediate)); + else + return cast<IntrinsicInst>( + tryCreateMaskedScatterBase(I, OffsetsIncoming, Builder, Immediate)); +} + +Value *MVEGatherScatterLowering::tryCreateIncrementingWBGatScat( + IntrinsicInst *I, Value *BasePtr, Value *Offsets, unsigned TypeScale, + IRBuilder<> &Builder) { + // Check whether this gather's offset is incremented by a constant - if so, + // and the load is of the right type, we can merge this into a QI gather + Loop *L = LI->getLoopFor(I->getParent()); + // Offsets that are worth merging into this instruction will be incremented + // by a constant, thus we're looking for an add of a phi and a constant + PHINode *Phi = dyn_cast<PHINode>(Offsets); + if (Phi == nullptr || Phi->getNumIncomingValues() != 2 || + Phi->getParent() != L->getHeader() || Phi->getNumUses() != 2) + // No phi means no IV to write back to; if there is a phi, we expect it + // to have exactly two incoming values; the only phis we are interested in + // will be loop IV's and have exactly two uses, one in their increment and + // one in the gather's gep + return nullptr; + + unsigned IncrementIndex = + Phi->getIncomingBlock(0) == L->getLoopLatch() ? 
+Value *MVEGatherScatterLowering::tryCreateIncrementingWBGatScat(
+    IntrinsicInst *I, Value *BasePtr, Value *Offsets, unsigned TypeScale,
+    IRBuilder<> &Builder) {
+  // Check whether this gather's offset is incremented by a constant - if so,
+  // and the load is of the right type, we can merge this into a QI gather
+  Loop *L = LI->getLoopFor(I->getParent());
+  // Offsets that are worth merging into this instruction will be incremented
+  // by a constant, thus we're looking for an add of a phi and a constant
+  PHINode *Phi = dyn_cast<PHINode>(Offsets);
+  if (Phi == nullptr || Phi->getNumIncomingValues() != 2 ||
+      Phi->getParent() != L->getHeader() || Phi->getNumUses() != 2)
+    // No phi means no IV to write back to; if there is a phi, we expect it
+    // to have exactly two incoming values; the only phis we are interested in
+    // will be loop IV's and have exactly two uses, one in their increment and
+    // one in the gather's gep
+    return nullptr;
+
+  unsigned IncrementIndex =
+      Phi->getIncomingBlock(0) == L->getLoopLatch() ? 0 : 1;
+  // Look through the phi to the phi increment
+  Offsets = Phi->getIncomingValue(IncrementIndex);
+
+  std::pair<Value *, int64_t> Add = getVarAndConst(Offsets, TypeScale);
+  if (Add.first == nullptr)
+    return nullptr;
+  Value *OffsetsIncoming = Add.first;
+  int64_t Immediate = Add.second;
+  if (OffsetsIncoming != Phi)
+    // Then the increment we are looking at is not an increment of the
+    // induction variable, and we don't want to do a writeback
+    return nullptr;
+
+  Builder.SetInsertPoint(&Phi->getIncomingBlock(1 - IncrementIndex)->back());
+  unsigned NumElems =
+      cast<FixedVectorType>(OffsetsIncoming->getType())->getNumElements();
+
+  // Make sure the offsets are scaled correctly
+  Instruction *ScaledOffsets = BinaryOperator::Create(
+      Instruction::Shl, Phi->getIncomingValue(1 - IncrementIndex),
+      Builder.CreateVectorSplat(NumElems, Builder.getInt32(TypeScale)),
+      "ScaledIndex", &Phi->getIncomingBlock(1 - IncrementIndex)->back());
+  // Add the base to the offsets
+  OffsetsIncoming = BinaryOperator::Create(
+      Instruction::Add, ScaledOffsets,
+      Builder.CreateVectorSplat(
+          NumElems,
+          Builder.CreatePtrToInt(
+              BasePtr,
+              cast<VectorType>(ScaledOffsets->getType())->getElementType())),
+      "StartIndex", &Phi->getIncomingBlock(1 - IncrementIndex)->back());
+  // The gather is pre-incrementing
+  OffsetsIncoming = BinaryOperator::Create(
+      Instruction::Sub, OffsetsIncoming,
+      Builder.CreateVectorSplat(NumElems, Builder.getInt32(Immediate)),
+      "PreIncrementStartIndex",
+      &Phi->getIncomingBlock(1 - IncrementIndex)->back());
+  Phi->setIncomingValue(1 - IncrementIndex, OffsetsIncoming);
+
+  Builder.SetInsertPoint(I);
+
+  Value *EndResult;
+  Value *NewInduction;
+  if (I->getIntrinsicID() == Intrinsic::masked_gather) {
+    // Build the incrementing gather
+    Value *Load = tryCreateMaskedGatherBaseWB(I, Phi, Builder, Immediate);
+    // One value to be handed to whoever uses the gather, one is the loop
+    // increment
+    EndResult = Builder.CreateExtractValue(Load, 0, "Gather");
+    NewInduction = Builder.CreateExtractValue(Load, 1, "GatherIncrement");
+  } else {
+    // Build the incrementing scatter
+    NewInduction = tryCreateMaskedScatterBaseWB(I, Phi, Builder, Immediate);
+    EndResult = NewInduction;
+  }
+  Instruction *AddInst = cast<Instruction>(Offsets);
+  AddInst->replaceAllUsesWith(NewInduction);
+  AddInst->eraseFromParent();
+  Phi->setIncomingValue(IncrementIndex, NewInduction);
+
+  return EndResult;
+}
+
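With writeback, the gather additionally returns the post-incremented address vector, which replaces the loop's own induction update; the "PreIncrementStartIndex" subtraction compensates for the pre-increment. A loop-level scalar model (illustrative, assuming a 16-byte step and a trip count that is a multiple of 4):

```cpp
#include <cstddef>
#include <cstdint>

// The per-iteration address update is fused into the gather itself, so the
// separate induction-variable add in the loop latch disappears.
void writebackGatherLoop(const int32_t *Base, int32_t *Dst, size_t N) {
  uintptr_t Addr[4];
  for (size_t I = 0; I < 4; ++I) // "PreIncrementStartIndex": start - one step
    Addr[I] = reinterpret_cast<uintptr_t>(Base + I) - 16;
  for (size_t N4 = 0; N4 < N; N4 += 4) {
    for (size_t I = 0; I < 4; ++I) {
      Addr[I] += 16; // writeback: IV update happens inside the gather
      Dst[N4 + I] = *reinterpret_cast<const int32_t *>(Addr[I]);
    }
  }
}
```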
+void MVEGatherScatterLowering::pushOutAdd(PHINode *&Phi,
+                                          Value *OffsSecondOperand,
+                                          unsigned StartIndex) {
+  LLVM_DEBUG(dbgs() << "masked gathers/scatters: optimising add instruction\n");
+  Instruction *InsertionPoint =
+      &cast<Instruction>(Phi->getIncomingBlock(StartIndex)->back());
+  // Initialize the phi with a vector that contains a sum of the constants
+  Instruction *NewIndex = BinaryOperator::Create(
+      Instruction::Add, Phi->getIncomingValue(StartIndex), OffsSecondOperand,
+      "PushedOutAdd", InsertionPoint);
+  unsigned IncrementIndex = StartIndex == 0 ? 1 : 0;
+
+  // Order such that start index comes first (this reduces mov's)
+  Phi->addIncoming(NewIndex, Phi->getIncomingBlock(StartIndex));
+  Phi->addIncoming(Phi->getIncomingValue(IncrementIndex),
+                   Phi->getIncomingBlock(IncrementIndex));
+  Phi->removeIncomingValue(IncrementIndex);
+  Phi->removeIncomingValue(StartIndex);
+}
+
+void MVEGatherScatterLowering::pushOutMul(PHINode *&Phi,
+                                          Value *IncrementPerRound,
+                                          Value *OffsSecondOperand,
+                                          unsigned LoopIncrement,
+                                          IRBuilder<> &Builder) {
+  LLVM_DEBUG(dbgs() << "masked gathers/scatters: optimising mul instruction\n");
+
+  // Create a new scalar add outside of the loop and transform it to a splat
+  // by which loop variable can be incremented
+  Instruction *InsertionPoint = &cast<Instruction>(
+      Phi->getIncomingBlock(LoopIncrement == 1 ? 0 : 1)->back());
+
+  // Create a new index
+  Value *StartIndex = BinaryOperator::Create(
+      Instruction::Mul, Phi->getIncomingValue(LoopIncrement == 1 ? 0 : 1),
+      OffsSecondOperand, "PushedOutMul", InsertionPoint);
+
+  Instruction *Product =
+      BinaryOperator::Create(Instruction::Mul, IncrementPerRound,
+                             OffsSecondOperand, "Product", InsertionPoint);
+  // Increment NewIndex by Product instead of the multiplication
+  Instruction *NewIncrement = BinaryOperator::Create(
+      Instruction::Add, Phi, Product, "IncrementPushedOutMul",
+      cast<Instruction>(Phi->getIncomingBlock(LoopIncrement)->back())
+          .getPrevNode());
+
+  Phi->addIncoming(StartIndex,
+                   Phi->getIncomingBlock(LoopIncrement == 1 ? 0 : 1));
+  Phi->addIncoming(NewIncrement, Phi->getIncomingBlock(LoopIncrement));
+  Phi->removeIncomingValue((unsigned)0);
+  Phi->removeIncomingValue((unsigned)0);
+  return;
+}
+
+// Check whether all usages of this instruction are as offsets of
+// gathers/scatters or simple arithmetics only used by gathers/scatters
+static bool hasAllGatScatUsers(Instruction *I) {
+  if (I->hasNUses(0)) {
+    return false;
+  }
+  bool Gatscat = true;
+  for (User *U : I->users()) {
+    if (!isa<Instruction>(U))
+      return false;
+    if (isa<GetElementPtrInst>(U) ||
+        isGatherScatter(dyn_cast<IntrinsicInst>(U))) {
+      return Gatscat;
+    } else {
+      unsigned OpCode = cast<Instruction>(U)->getOpcode();
+      if ((OpCode == Instruction::Add || OpCode == Instruction::Mul) &&
+          hasAllGatScatUsers(cast<Instruction>(U))) {
+        continue;
+      }
+      return false;
+    }
+  }
+  return Gatscat;
+}
+
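optimiseOffsets (below) together with pushOutAdd/pushOutMul hoists loop-invariant adds and multiplies on the offset IV into the phi's start value and per-round increment. In scalar terms (a sketch of the before/after shapes, not the pass's code):

```cpp
#include <cstdint>

// Before: the loop-invariant multiply runs every iteration.
int64_t sumBefore(const int64_t *A, int64_t M, int Iters) {
  int64_t Sum = 0, IV = 0;
  for (int I = 0; I < Iters; ++I, IV += 4)
    Sum += A[IV * M]; // mul inside the loop body
  return Sum;
}

// After: the multiply is paid once outside the loop ("PushedOutMul" for the
// start value, "Product" for the increment); the body is a plain add.
int64_t sumAfter(const int64_t *A, int64_t M, int Iters) {
  int64_t Sum = 0;
  int64_t Phi = 0 * M;           // start value scaled once
  const int64_t Product = 4 * M; // increment scaled once
  for (int I = 0; I < Iters; ++I, Phi += Product)
    Sum += A[Phi]; // the phi carries the scaled offset directly
  return Sum;
}
```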
+bool MVEGatherScatterLowering::optimiseOffsets(Value *Offsets, BasicBlock *BB,
+                                               LoopInfo *LI) {
+  LLVM_DEBUG(dbgs() << "masked gathers/scatters: trying to optimize\n");
+  // Optimise the addresses of gathers/scatters by moving invariant
+  // calculations out of the loop
+  if (!isa<Instruction>(Offsets))
+    return false;
+  Instruction *Offs = cast<Instruction>(Offsets);
+  if (Offs->getOpcode() != Instruction::Add &&
+      Offs->getOpcode() != Instruction::Mul)
+    return false;
+  Loop *L = LI->getLoopFor(BB);
+  if (L == nullptr)
+    return false;
+  if (!Offs->hasOneUse()) {
+    if (!hasAllGatScatUsers(Offs))
+      return false;
+  }
+
+  // Find out which, if any, operand of the instruction
+  // is a phi node
+  PHINode *Phi;
+  int OffsSecondOp;
+  if (isa<PHINode>(Offs->getOperand(0))) {
+    Phi = cast<PHINode>(Offs->getOperand(0));
+    OffsSecondOp = 1;
+  } else if (isa<PHINode>(Offs->getOperand(1))) {
+    Phi = cast<PHINode>(Offs->getOperand(1));
+    OffsSecondOp = 0;
+  } else {
+    bool Changed = true;
+    if (isa<Instruction>(Offs->getOperand(0)) &&
+        L->contains(cast<Instruction>(Offs->getOperand(0))))
+      Changed |= optimiseOffsets(Offs->getOperand(0), BB, LI);
+    if (isa<Instruction>(Offs->getOperand(1)) &&
+        L->contains(cast<Instruction>(Offs->getOperand(1))))
+      Changed |= optimiseOffsets(Offs->getOperand(1), BB, LI);
+    if (!Changed) {
+      return false;
+    } else {
+      if (isa<PHINode>(Offs->getOperand(0))) {
+        Phi = cast<PHINode>(Offs->getOperand(0));
+        OffsSecondOp = 1;
+      } else if (isa<PHINode>(Offs->getOperand(1))) {
+        Phi = cast<PHINode>(Offs->getOperand(1));
+        OffsSecondOp = 0;
+      } else {
+        return false;
+      }
+    }
+  }
+  // A phi node we want to perform this function on should be from the
+  // loop header, and shouldn't have more than 2 incoming values
+  if (Phi->getParent() != L->getHeader() ||
+      Phi->getNumIncomingValues() != 2)
+    return false;
+
+  // The phi must be an induction variable
+  Instruction *Op;
+  int IncrementingBlock = -1;
+
+  for (int i = 0; i < 2; i++)
+    if ((Op = dyn_cast<Instruction>(Phi->getIncomingValue(i))) != nullptr)
+      if (Op->getOpcode() == Instruction::Add &&
+          (Op->getOperand(0) == Phi || Op->getOperand(1) == Phi))
+        IncrementingBlock = i;
+  if (IncrementingBlock == -1)
+    return false;
+
+  Instruction *IncInstruction =
+      cast<Instruction>(Phi->getIncomingValue(IncrementingBlock));
+
+  // If the phi is not used by anything else, we can just adapt it when
+  // replacing the instruction; if it is, we'll have to duplicate it
+  PHINode *NewPhi;
+  Value *IncrementPerRound = IncInstruction->getOperand(
+      (IncInstruction->getOperand(0) == Phi) ? 1 : 0);
+
+  // Get the value that is added to/multiplied with the phi
+  Value *OffsSecondOperand = Offs->getOperand(OffsSecondOp);
+
+  if (IncrementPerRound->getType() != OffsSecondOperand->getType())
+    // Something has gone wrong, abort
+    return false;
+
+  // Only proceed if the increment per round is a constant or an instruction
+  // which does not originate from within the loop
+  if (!isa<Constant>(IncrementPerRound) &&
+      !(isa<Instruction>(IncrementPerRound) &&
+        !L->contains(cast<Instruction>(IncrementPerRound))))
+    return false;
+
+  if (Phi->getNumUses() == 2) {
+    // No other users -> reuse existing phi (One user is the instruction
+    // we're looking at, the other is the phi increment)
+    if (IncInstruction->getNumUses() != 1) {
+      // If the incrementing instruction does have more users than
+      // our phi, we need to copy it
+      IncInstruction = BinaryOperator::Create(
+          Instruction::BinaryOps(IncInstruction->getOpcode()), Phi,
+          IncrementPerRound, "LoopIncrement", IncInstruction);
+      Phi->setIncomingValue(IncrementingBlock, IncInstruction);
+    }
+    NewPhi = Phi;
+  } else {
+    // There are other users -> create a new phi
+    NewPhi = PHINode::Create(Phi->getType(), 0, "NewPhi", Phi);
+    std::vector<Value *> Increases;
+    // Copy the incoming values of the old phi
+    NewPhi->addIncoming(Phi->getIncomingValue(IncrementingBlock == 1 ? 0 : 1),
+                        Phi->getIncomingBlock(IncrementingBlock == 1 ? 0 : 1));
+    IncInstruction = BinaryOperator::Create(
+        Instruction::BinaryOps(IncInstruction->getOpcode()), NewPhi,
+        IncrementPerRound, "LoopIncrement", IncInstruction);
+    NewPhi->addIncoming(IncInstruction,
+                        Phi->getIncomingBlock(IncrementingBlock));
+    IncrementingBlock = 1;
+  }
+
+  IRBuilder<> Builder(BB->getContext());
+  Builder.SetInsertPoint(Phi);
+  Builder.SetCurrentDebugLocation(Offs->getDebugLoc());
+
+  switch (Offs->getOpcode()) {
+  case Instruction::Add:
+    pushOutAdd(NewPhi, OffsSecondOperand, IncrementingBlock == 1 ? 0 : 1);
+    break;
+  case Instruction::Mul:
+    pushOutMul(NewPhi, IncrementPerRound, OffsSecondOperand, IncrementingBlock,
+               Builder);
+    break;
+  default:
+    return false;
+  }
+  LLVM_DEBUG(
+      dbgs() << "masked gathers/scatters: simplified loop variable add/mul\n");
+
+  // The instruction has now been "absorbed" into the phi value
+  Offs->replaceAllUsesWith(NewPhi);
+  if (Offs->hasNUses(0))
+    Offs->eraseFromParent();
+  // Clean up the old increment in case it's unused because we built a new
+  // one
+  if (IncInstruction->hasNUses(0))
+    IncInstruction->eraseFromParent();
+
+  return true;
 }

 bool MVEGatherScatterLowering::runOnFunction(Function &F) {
@@ -282,20 +986,51 @@ bool MVEGatherScatterLowering::runOnFunction(Function &F) {
   auto *ST = &TM.getSubtarget<ARMSubtarget>(F);
   if (!ST->hasMVEIntegerOps())
     return false;
+  LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
   SmallVector<IntrinsicInst *, 4> Gathers;
+  SmallVector<IntrinsicInst *, 4> Scatters;
+
+  bool Changed = false;
+
   for (BasicBlock &BB : F) {
     for (Instruction &I : BB) {
       IntrinsicInst *II = dyn_cast<IntrinsicInst>(&I);
-      if (II && II->getIntrinsicID() == Intrinsic::masked_gather)
+      if (II && II->getIntrinsicID() == Intrinsic::masked_gather) {
         Gathers.push_back(II);
+        if (isa<GetElementPtrInst>(II->getArgOperand(0)))
+          Changed |= optimiseOffsets(
+              cast<Instruction>(II->getArgOperand(0))->getOperand(1),
+              II->getParent(), LI);
+      } else if (II && II->getIntrinsicID() == Intrinsic::masked_scatter) {
+        Scatters.push_back(II);
+        if (isa<GetElementPtrInst>(II->getArgOperand(1)))
+          Changed |= optimiseOffsets(
+              cast<Instruction>(II->getArgOperand(1))->getOperand(1),
+              II->getParent(), LI);
+      }
     }
   }

-  if (Gathers.empty())
-    return false;
+  for (unsigned i = 0; i < Gathers.size(); i++) {
+    IntrinsicInst *I = Gathers[i];
+    Value *L = lowerGather(I);
+    if (L == nullptr)
+      continue;

-  for (IntrinsicInst *I : Gathers)
-    lowerGather(I);
+    // Get rid of any now dead instructions
+    SimplifyInstructionsInBlock(cast<Instruction>(L)->getParent());
+    Changed = true;
+  }

-  return true;
+  for (unsigned i = 0; i < Scatters.size(); i++) {
+    IntrinsicInst *I = Scatters[i];
+    Value *S = lowerScatter(I);
+    if (S == nullptr)
+      continue;
+
+    // Get rid of any now dead instructions
+    SimplifyInstructionsInBlock(cast<Instruction>(S)->getParent());
+    Changed = true;
+  }
+  return Changed;
 }
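Note the cleanup step in runOnFunction: once a gather or scatter has been replaced, the old GEP, zext and phi chain is usually dead, and SimplifyInstructionsInBlock from the newly included llvm/Transforms/Utils/Local.h sweeps it up. A sketch of that call (the helper name is illustrative):

```cpp
#include "llvm/IR/Instructions.h"
#include "llvm/Transforms/Utils/Local.h"

// After each successful lowering, remove the trivially dead instructions
// left behind in the same block as the replacement.
static void cleanupAfterLowering(llvm::Instruction *NewInst) {
  llvm::SimplifyInstructionsInBlock(NewInst->getParent());
}
```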