Diffstat (limited to 'contrib/llvm-project/llvm/lib/CodeGen/ScalarizeMaskedMemIntrin.cpp')
-rw-r--r-- | contrib/llvm-project/llvm/lib/CodeGen/ScalarizeMaskedMemIntrin.cpp | 765
1 file changed, 765 insertions, 0 deletions
diff --git a/contrib/llvm-project/llvm/lib/CodeGen/ScalarizeMaskedMemIntrin.cpp b/contrib/llvm-project/llvm/lib/CodeGen/ScalarizeMaskedMemIntrin.cpp
new file mode 100644
index 000000000000..7776dffb4e9c
--- /dev/null
+++ b/contrib/llvm-project/llvm/lib/CodeGen/ScalarizeMaskedMemIntrin.cpp
@@ -0,0 +1,765 @@
+//===- ScalarizeMaskedMemIntrin.cpp - Scalarize unsupported masked mem ---===//
+//                                   intrinsics
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This pass replaces masked memory intrinsics that are unsupported by the
+// target with a chain of basic blocks that handle the elements one by one
+// when the corresponding mask bit is set.
+//
+//===----------------------------------------------------------------------===//
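+
+// For illustration (a sketch; the value names and the 2-lane width are chosen
+// for brevity), a call such as
+//
+//   %res = call <2 x i32> @llvm.masked.load.v2i32.p0v2i32(
+//              <2 x i32>* %addr, i32 4, <2 x i1> %mask, <2 x i32> %passthru)
+//
+// becomes one conditional block per lane, each guarded by the corresponding
+// bit of %mask and merged back with phis; all-true and compile-time-constant
+// masks are lowered without any control flow (see the per-intrinsic
+// translations below).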
+
+#include "llvm/ADT/Twine.h"
+#include "llvm/Analysis/TargetTransformInfo.h"
+#include "llvm/CodeGen/TargetSubtargetInfo.h"
+#include "llvm/IR/BasicBlock.h"
+#include "llvm/IR/Constant.h"
+#include "llvm/IR/Constants.h"
+#include "llvm/IR/DerivedTypes.h"
+#include "llvm/IR/Function.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/InstrTypes.h"
+#include "llvm/IR/Instruction.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/IntrinsicInst.h"
+#include "llvm/IR/Intrinsics.h"
+#include "llvm/IR/Type.h"
+#include "llvm/IR/Value.h"
+#include "llvm/Pass.h"
+#include "llvm/Support/Casting.h"
+#include <algorithm>
+#include <cassert>
+
+using namespace llvm;
+
+#define DEBUG_TYPE "scalarize-masked-mem-intrin"
+
+namespace {
+
+class ScalarizeMaskedMemIntrin : public FunctionPass {
+  const TargetTransformInfo *TTI = nullptr;
+
+public:
+  static char ID; // Pass identification, replacement for typeid
+
+  explicit ScalarizeMaskedMemIntrin() : FunctionPass(ID) {
+    initializeScalarizeMaskedMemIntrinPass(*PassRegistry::getPassRegistry());
+  }
+
+  bool runOnFunction(Function &F) override;
+
+  StringRef getPassName() const override {
+    return "Scalarize Masked Memory Intrinsics";
+  }
+
+  void getAnalysisUsage(AnalysisUsage &AU) const override {
+    AU.addRequired<TargetTransformInfoWrapperPass>();
+  }
+
+private:
+  bool optimizeBlock(BasicBlock &BB, bool &ModifiedDT);
+  bool optimizeCallInst(CallInst *CI, bool &ModifiedDT);
+};
+
+} // end anonymous namespace
+
+char ScalarizeMaskedMemIntrin::ID = 0;
+
+INITIALIZE_PASS(ScalarizeMaskedMemIntrin, DEBUG_TYPE,
+                "Scalarize unsupported masked memory intrinsics", false, false)
+
+FunctionPass *llvm::createScalarizeMaskedMemIntrinPass() {
+  return new ScalarizeMaskedMemIntrin();
+}
+
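+// Return true if \p Mask is a vector Constant whose elements are all
+// ConstantInt, i.e. every lane of the mask is known at compile time; such
+// masks can be scalarized below without emitting any branches.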
+static bool isConstantIntVector(Value *Mask) {
+  Constant *C = dyn_cast<Constant>(Mask);
+  if (!C)
+    return false;
+
+  unsigned NumElts = Mask->getType()->getVectorNumElements();
+  for (unsigned i = 0; i != NumElts; ++i) {
+    Constant *CElt = C->getAggregateElement(i);
+    if (!CElt || !isa<ConstantInt>(CElt))
+      return false;
+  }
+
+  return true;
+}
+
+// Translate a masked load intrinsic like
+//   <16 x i32> @llvm.masked.load(<16 x i32>* %addr, i32 align,
+//                                <16 x i1> %mask, <16 x i32> %passthru)
+// to a chain of basic blocks, loading the elements one by one if the
+// appropriate mask bit is set.
+//
+//  %1 = bitcast i8* %addr to i32*
+//  %2 = extractelement <16 x i1> %mask, i32 0
+//  br i1 %2, label %cond.load, label %else
+//
+// cond.load:                                        ; preds = %0
+//  %3 = getelementptr i32* %1, i32 0
+//  %4 = load i32* %3
+//  %5 = insertelement <16 x i32> %passthru, i32 %4, i32 0
+//  br label %else
+//
+// else:                                             ; preds = %0, %cond.load
+//  %res.phi.else = phi <16 x i32> [ %5, %cond.load ], [ undef, %0 ]
+//  %6 = extractelement <16 x i1> %mask, i32 1
+//  br i1 %6, label %cond.load1, label %else2
+//
+// cond.load1:                                       ; preds = %else
+//  %7 = getelementptr i32* %1, i32 1
+//  %8 = load i32* %7
+//  %9 = insertelement <16 x i32> %res.phi.else, i32 %8, i32 1
+//  br label %else2
+//
+// else2:                                          ; preds = %else, %cond.load1
+//  %res.phi.else3 = phi <16 x i32> [ %9, %cond.load1 ], [ %res.phi.else, %else ]
+//  %10 = extractelement <16 x i1> %mask, i32 2
+//  br i1 %10, label %cond.load4, label %else5
+//
+static void scalarizeMaskedLoad(CallInst *CI, bool &ModifiedDT) {
+  Value *Ptr = CI->getArgOperand(0);
+  Value *Alignment = CI->getArgOperand(1);
+  Value *Mask = CI->getArgOperand(2);
+  Value *Src0 = CI->getArgOperand(3);
+
+  unsigned AlignVal = cast<ConstantInt>(Alignment)->getZExtValue();
+  VectorType *VecType = cast<VectorType>(CI->getType());
+
+  Type *EltTy = VecType->getElementType();
+
+  IRBuilder<> Builder(CI->getContext());
+  Instruction *InsertPt = CI;
+  BasicBlock *IfBlock = CI->getParent();
+
+  Builder.SetInsertPoint(InsertPt);
+  Builder.SetCurrentDebugLocation(CI->getDebugLoc());
+
+  // Short-cut if the mask is all-true.
+  if (isa<Constant>(Mask) && cast<Constant>(Mask)->isAllOnesValue()) {
+    Value *NewI = Builder.CreateAlignedLoad(VecType, Ptr, AlignVal);
+    CI->replaceAllUsesWith(NewI);
+    CI->eraseFromParent();
+    return;
+  }
+
+  // Adjust alignment for the scalar instruction.
+  AlignVal = MinAlign(AlignVal, EltTy->getPrimitiveSizeInBits() / 8);
+  // Bitcast %addr from i8* to EltTy*
+  Type *NewPtrType =
+      EltTy->getPointerTo(Ptr->getType()->getPointerAddressSpace());
+  Value *FirstEltPtr = Builder.CreateBitCast(Ptr, NewPtrType);
+  unsigned VectorWidth = VecType->getNumElements();
+
+  // The result vector
+  Value *VResult = Src0;
+
+  if (isConstantIntVector(Mask)) {
+    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
+      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
+        continue;
+      Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, FirstEltPtr, Idx);
+      LoadInst *Load = Builder.CreateAlignedLoad(EltTy, Gep, AlignVal);
+      VResult = Builder.CreateInsertElement(VResult, Load, Idx);
+    }
+    CI->replaceAllUsesWith(VResult);
+    CI->eraseFromParent();
+    return;
+  }
+
+  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
+    // Fill the "else" block, created in the previous iteration
+    //
+    //  %res.phi.else3 = phi <16 x i32> [ %11, %cond.load1 ], [ %res.phi.else, %else ]
+    //  %mask_1 = extractelement <16 x i1> %mask, i32 Idx
+    //  br i1 %mask_1, label %cond.load, label %else
+    //
+    Value *Predicate = Builder.CreateExtractElement(Mask, Idx);
+
+    // Create "cond" block
+    //
+    //  %EltAddr = getelementptr i32* %1, i32 0
+    //  %Elt = load i32* %EltAddr
+    //  VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
+    //
+    BasicBlock *CondBlock = IfBlock->splitBasicBlock(InsertPt->getIterator(),
+                                                     "cond.load");
+    Builder.SetInsertPoint(InsertPt);
+
+    Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, FirstEltPtr, Idx);
+    LoadInst *Load = Builder.CreateAlignedLoad(EltTy, Gep, AlignVal);
+    Value *NewVResult = Builder.CreateInsertElement(VResult, Load, Idx);
+
+    // Create "else" block, fill it in the next iteration
+    BasicBlock *NewIfBlock =
+        CondBlock->splitBasicBlock(InsertPt->getIterator(), "else");
+    Builder.SetInsertPoint(InsertPt);
+    Instruction *OldBr = IfBlock->getTerminator();
+    BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
+    OldBr->eraseFromParent();
+    BasicBlock *PrevIfBlock = IfBlock;
+    IfBlock = NewIfBlock;
+
+    // Create the phi to join the new and previous value.
+    PHINode *Phi = Builder.CreatePHI(VecType, 2, "res.phi.else");
+    Phi->addIncoming(NewVResult, CondBlock);
+    Phi->addIncoming(VResult, PrevIfBlock);
+    VResult = Phi;
+  }
+
+  CI->replaceAllUsesWith(VResult);
+  CI->eraseFromParent();
+
+  ModifiedDT = true;
+}
+
+// Translate a masked store intrinsic, like
+//   void @llvm.masked.store(<16 x i32> %src, <16 x i32>* %addr, i32 align,
+//                           <16 x i1> %mask)
+// to a chain of basic blocks, storing the elements one by one if the
+// appropriate mask bit is set.
+//
+//  %1 = bitcast i8* %addr to i32*
+//  %2 = extractelement <16 x i1> %mask, i32 0
+//  br i1 %2, label %cond.store, label %else
+//
+// cond.store:                                       ; preds = %0
+//  %3 = extractelement <16 x i32> %val, i32 0
+//  %4 = getelementptr i32* %1, i32 0
+//  store i32 %3, i32* %4
+//  br label %else
+//
+// else:                                             ; preds = %0, %cond.store
+//  %5 = extractelement <16 x i1> %mask, i32 1
+//  br i1 %5, label %cond.store1, label %else2
+//
+// cond.store1:                                      ; preds = %else
+//  %6 = extractelement <16 x i32> %val, i32 1
+//  %7 = getelementptr i32* %1, i32 1
+//  store i32 %6, i32* %7
+//  br label %else2
+//  . . .
+static void scalarizeMaskedStore(CallInst *CI, bool &ModifiedDT) {
+  Value *Src = CI->getArgOperand(0);
+  Value *Ptr = CI->getArgOperand(1);
+  Value *Alignment = CI->getArgOperand(2);
+  Value *Mask = CI->getArgOperand(3);
+
+  unsigned AlignVal = cast<ConstantInt>(Alignment)->getZExtValue();
+  VectorType *VecType = cast<VectorType>(Src->getType());
+
+  Type *EltTy = VecType->getElementType();
+
+  IRBuilder<> Builder(CI->getContext());
+  Instruction *InsertPt = CI;
+  BasicBlock *IfBlock = CI->getParent();
+  Builder.SetInsertPoint(InsertPt);
+  Builder.SetCurrentDebugLocation(CI->getDebugLoc());
+
+  // Short-cut if the mask is all-true.
+  if (isa<Constant>(Mask) && cast<Constant>(Mask)->isAllOnesValue()) {
+    Builder.CreateAlignedStore(Src, Ptr, AlignVal);
+    CI->eraseFromParent();
+    return;
+  }
+
+  // Adjust alignment for the scalar instruction.
+  AlignVal = MinAlign(AlignVal, EltTy->getPrimitiveSizeInBits() / 8);
+  // Bitcast %addr from i8* to EltTy*
+  Type *NewPtrType =
+      EltTy->getPointerTo(Ptr->getType()->getPointerAddressSpace());
+  Value *FirstEltPtr = Builder.CreateBitCast(Ptr, NewPtrType);
+  unsigned VectorWidth = VecType->getNumElements();
+
+  if (isConstantIntVector(Mask)) {
+    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
+      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
+        continue;
+      Value *OneElt = Builder.CreateExtractElement(Src, Idx);
+      Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, FirstEltPtr, Idx);
+      Builder.CreateAlignedStore(OneElt, Gep, AlignVal);
+    }
+    CI->eraseFromParent();
+    return;
+  }
+
+  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
+    // Fill the "else" block, created in the previous iteration
+    //
+    //  %mask_1 = extractelement <16 x i1> %mask, i32 Idx
+    //  br i1 %mask_1, label %cond.store, label %else
+    //
+    Value *Predicate = Builder.CreateExtractElement(Mask, Idx);
+
+    // Create "cond" block
+    //
+    //  %OneElt = extractelement <16 x i32> %Src, i32 Idx
+    //  %EltAddr = getelementptr i32* %1, i32 0
+    //  store i32 %OneElt, i32* %EltAddr
+    //
+    BasicBlock *CondBlock =
+        IfBlock->splitBasicBlock(InsertPt->getIterator(), "cond.store");
+    Builder.SetInsertPoint(InsertPt);
+
+    Value *OneElt = Builder.CreateExtractElement(Src, Idx);
+    Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, FirstEltPtr, Idx);
+    Builder.CreateAlignedStore(OneElt, Gep, AlignVal);
+
+    // Create "else" block, fill it in the next iteration
+    BasicBlock *NewIfBlock =
+        CondBlock->splitBasicBlock(InsertPt->getIterator(), "else");
+    Builder.SetInsertPoint(InsertPt);
+    Instruction *OldBr = IfBlock->getTerminator();
+    BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
+    OldBr->eraseFromParent();
+    IfBlock = NewIfBlock;
+  }
+  CI->eraseFromParent();
+
+  ModifiedDT = true;
+}
+
+// Translate a masked gather intrinsic like
+//   <16 x i32> @llvm.masked.gather.v16i32(<16 x i32*> %Ptrs, i32 4,
+//                                         <16 x i1> %Mask, <16 x i32> %Src)
+// to a chain of basic blocks, loading the elements one by one if the
+// appropriate mask bit is set.
+//
+//  %Ptrs = getelementptr i32, i32* %base, <16 x i64> %ind
+//  %Mask0 = extractelement <16 x i1> %Mask, i32 0
+//  br i1 %Mask0, label %cond.load, label %else
+//
+// cond.load:
+//  %Ptr0 = extractelement <16 x i32*> %Ptrs, i32 0
+//  %Load0 = load i32, i32* %Ptr0, align 4
+//  %Res0 = insertelement <16 x i32> undef, i32 %Load0, i32 0
+//  br label %else
+//
+// else:
+//  %res.phi.else = phi <16 x i32> [ %Res0, %cond.load ], [ undef, %0 ]
+//  %Mask1 = extractelement <16 x i1> %Mask, i32 1
+//  br i1 %Mask1, label %cond.load1, label %else2
+//
+// cond.load1:
+//  %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
+//  %Load1 = load i32, i32* %Ptr1, align 4
+//  %Res1 = insertelement <16 x i32> %res.phi.else, i32 %Load1, i32 1
+//  br label %else2
+//  . . .
+//  %Result = select <16 x i1> %Mask, <16 x i32> %res.phi.select, <16 x i32> %Src
+//  ret <16 x i32> %Result
+static void scalarizeMaskedGather(CallInst *CI, bool &ModifiedDT) {
+  Value *Ptrs = CI->getArgOperand(0);
+  Value *Alignment = CI->getArgOperand(1);
+  Value *Mask = CI->getArgOperand(2);
+  Value *Src0 = CI->getArgOperand(3);
+
+  VectorType *VecType = cast<VectorType>(CI->getType());
+  Type *EltTy = VecType->getElementType();
+
+  IRBuilder<> Builder(CI->getContext());
+  Instruction *InsertPt = CI;
+  BasicBlock *IfBlock = CI->getParent();
+  Builder.SetInsertPoint(InsertPt);
+  unsigned AlignVal = cast<ConstantInt>(Alignment)->getZExtValue();
+
+  Builder.SetCurrentDebugLocation(CI->getDebugLoc());
+
+  // The result vector
+  Value *VResult = Src0;
+  unsigned VectorWidth = VecType->getNumElements();
+
+  // Short-cut if the mask is a vector of constants.
+  if (isConstantIntVector(Mask)) {
+    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
+      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
+        continue;
+      Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
+      LoadInst *Load =
+          Builder.CreateAlignedLoad(EltTy, Ptr, AlignVal, "Load" + Twine(Idx));
+      VResult =
+          Builder.CreateInsertElement(VResult, Load, Idx, "Res" + Twine(Idx));
+    }
+    CI->replaceAllUsesWith(VResult);
+    CI->eraseFromParent();
+    return;
+  }
+
+  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
+    // Fill the "else" block, created in the previous iteration
+    //
+    //  %Mask1 = extractelement <16 x i1> %Mask, i32 1
+    //  br i1 %Mask1, label %cond.load, label %else
+    //
+    Value *Predicate =
+        Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));
+
+    // Create "cond" block
+    //
+    //  %EltAddr = getelementptr i32* %1, i32 0
+    //  %Elt = load i32* %EltAddr
+    //  VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
+    //
+    BasicBlock *CondBlock = IfBlock->splitBasicBlock(InsertPt, "cond.load");
+    Builder.SetInsertPoint(InsertPt);
+
+    Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
+    LoadInst *Load =
+        Builder.CreateAlignedLoad(EltTy, Ptr, AlignVal, "Load" + Twine(Idx));
+    Value *NewVResult =
+        Builder.CreateInsertElement(VResult, Load, Idx, "Res" + Twine(Idx));
+
+    // Create "else" block, fill it in the next iteration
+    BasicBlock *NewIfBlock = CondBlock->splitBasicBlock(InsertPt, "else");
+    Builder.SetInsertPoint(InsertPt);
+    Instruction *OldBr = IfBlock->getTerminator();
+    BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
+    OldBr->eraseFromParent();
+    BasicBlock *PrevIfBlock = IfBlock;
+    IfBlock = NewIfBlock;
+
+    PHINode *Phi = Builder.CreatePHI(VecType, 2, "res.phi.else");
+    Phi->addIncoming(NewVResult, CondBlock);
+    Phi->addIncoming(VResult, PrevIfBlock);
+    VResult = Phi;
+  }
+
+  CI->replaceAllUsesWith(VResult);
+  CI->eraseFromParent();
+
+  ModifiedDT = true;
+}
+
+// Translate a masked scatter intrinsic, like
+//   void @llvm.masked.scatter.v16i32(<16 x i32> %Src, <16 x i32*> %Ptrs, i32 4,
+//                                    <16 x i1> %Mask)
+// to a chain of basic blocks, storing the elements one by one if the
+// appropriate mask bit is set.
+//
+//  %Ptrs = getelementptr i32, i32* %ptr, <16 x i64> %ind
+//  %Mask0 = extractelement <16 x i1> %Mask, i32 0
+//  br i1 %Mask0, label %cond.store, label %else
+//
+// cond.store:
+//  %Elt0 = extractelement <16 x i32> %Src, i32 0
+//  %Ptr0 = extractelement <16 x i32*> %Ptrs, i32 0
+//  store i32 %Elt0, i32* %Ptr0, align 4
+//  br label %else
+//
+// else:
+//  %Mask1 = extractelement <16 x i1> %Mask, i32 1
+//  br i1 %Mask1, label %cond.store1, label %else2
+//
+// cond.store1:
+//  %Elt1 = extractelement <16 x i32> %Src, i32 1
+//  %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
+//  store i32 %Elt1, i32* %Ptr1, align 4
+//  br label %else2
+//  . . .
+static void scalarizeMaskedScatter(CallInst *CI, bool &ModifiedDT) {
+  Value *Src = CI->getArgOperand(0);
+  Value *Ptrs = CI->getArgOperand(1);
+  Value *Alignment = CI->getArgOperand(2);
+  Value *Mask = CI->getArgOperand(3);
+
+  assert(isa<VectorType>(Src->getType()) &&
+         "Unexpected data type in masked scatter intrinsic");
+  assert(isa<VectorType>(Ptrs->getType()) &&
+         isa<PointerType>(Ptrs->getType()->getVectorElementType()) &&
+         "Vector of pointers is expected in masked scatter intrinsic");
+
+  IRBuilder<> Builder(CI->getContext());
+  Instruction *InsertPt = CI;
+  BasicBlock *IfBlock = CI->getParent();
+  Builder.SetInsertPoint(InsertPt);
+  Builder.SetCurrentDebugLocation(CI->getDebugLoc());
+
+  unsigned AlignVal = cast<ConstantInt>(Alignment)->getZExtValue();
+  unsigned VectorWidth = Src->getType()->getVectorNumElements();
+
+  // Short-cut if the mask is a vector of constants.
+  if (isConstantIntVector(Mask)) {
+    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
+      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
+        continue;
+      Value *OneElt =
+          Builder.CreateExtractElement(Src, Idx, "Elt" + Twine(Idx));
+      Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
+      Builder.CreateAlignedStore(OneElt, Ptr, AlignVal);
+    }
+    CI->eraseFromParent();
+    return;
+  }
+
+  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
+    // Fill the "else" block, created in the previous iteration
+    //
+    //  %Mask1 = extractelement <16 x i1> %Mask, i32 Idx
+    //  br i1 %Mask1, label %cond.store, label %else
+    //
+    Value *Predicate =
+        Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));
+
+    // Create "cond" block
+    //
+    //  %Elt1 = extractelement <16 x i32> %Src, i32 1
+    //  %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
+    //  store i32 %Elt1, i32* %Ptr1
+    //
+    BasicBlock *CondBlock = IfBlock->splitBasicBlock(InsertPt, "cond.store");
+    Builder.SetInsertPoint(InsertPt);
+
+    Value *OneElt = Builder.CreateExtractElement(Src, Idx, "Elt" + Twine(Idx));
+    Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
+    Builder.CreateAlignedStore(OneElt, Ptr, AlignVal);
+
+    // Create "else" block, fill it in the next iteration
+    BasicBlock *NewIfBlock = CondBlock->splitBasicBlock(InsertPt, "else");
+    Builder.SetInsertPoint(InsertPt);
+    Instruction *OldBr = IfBlock->getTerminator();
+    BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
+    OldBr->eraseFromParent();
+    IfBlock = NewIfBlock;
+  }
+  CI->eraseFromParent();
+
+  ModifiedDT = true;
+}
+
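+// Translate a masked expand-load intrinsic, like
+//   <16 x i32> @llvm.masked.expandload.v16i32(i32* %ptr, <16 x i1> %mask,
+//                                             <16 x i32> %passthru)
+// to a chain of conditional blocks (a sketch; the names follow the other
+// translations above). Unlike a masked load, the pointer is advanced by one
+// element only after an enabled lane, so consecutive memory elements are
+// expanded into the enabled lanes of the result, while disabled lanes keep
+// the pass-through value. The pointer itself is threaded through the blocks
+// with a "ptr.phi.else" phi.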
+static void scalarizeMaskedExpandLoad(CallInst *CI, bool &ModifiedDT) {
+  Value *Ptr = CI->getArgOperand(0);
+  Value *Mask = CI->getArgOperand(1);
+  Value *PassThru = CI->getArgOperand(2);
+
+  VectorType *VecType = cast<VectorType>(CI->getType());
+
+  Type *EltTy = VecType->getElementType();
+
+  IRBuilder<> Builder(CI->getContext());
+  Instruction *InsertPt = CI;
+  BasicBlock *IfBlock = CI->getParent();
+
+  Builder.SetInsertPoint(InsertPt);
+  Builder.SetCurrentDebugLocation(CI->getDebugLoc());
+
+  unsigned VectorWidth = VecType->getNumElements();
+
+  // The result vector
+  Value *VResult = PassThru;
+
+  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
+    // Fill the "else" block, created in the previous iteration
+    //
+    //  %res.phi.else3 = phi <16 x i32> [ %11, %cond.load1 ], [ %res.phi.else, %else ]
+    //  %mask_1 = extractelement <16 x i1> %mask, i32 Idx
+    //  br i1 %mask_1, label %cond.load, label %else
+    //
+    Value *Predicate = Builder.CreateExtractElement(Mask, Idx);
+
+    // Create "cond" block
+    //
+    //  %EltAddr = getelementptr i32* %1, i32 0
+    //  %Elt = load i32* %EltAddr
+    //  VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
+    //
+    BasicBlock *CondBlock = IfBlock->splitBasicBlock(InsertPt->getIterator(),
+                                                     "cond.load");
+    Builder.SetInsertPoint(InsertPt);
+
+    LoadInst *Load = Builder.CreateAlignedLoad(EltTy, Ptr, 1);
+    Value *NewVResult = Builder.CreateInsertElement(VResult, Load, Idx);
+
+    // Move the pointer if there are more blocks to come.
+    Value *NewPtr;
+    if ((Idx + 1) != VectorWidth)
+      NewPtr = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, 1);
+
+    // Create "else" block, fill it in the next iteration
+    BasicBlock *NewIfBlock =
+        CondBlock->splitBasicBlock(InsertPt->getIterator(), "else");
+    Builder.SetInsertPoint(InsertPt);
+    Instruction *OldBr = IfBlock->getTerminator();
+    BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
+    OldBr->eraseFromParent();
+    BasicBlock *PrevIfBlock = IfBlock;
+    IfBlock = NewIfBlock;
+
+    // Create the phi to join the new and previous value.
+    PHINode *ResultPhi = Builder.CreatePHI(VecType, 2, "res.phi.else");
+    ResultPhi->addIncoming(NewVResult, CondBlock);
+    ResultPhi->addIncoming(VResult, PrevIfBlock);
+    VResult = ResultPhi;
+
+    // Add a PHI for the pointer if this isn't the last iteration.
+    if ((Idx + 1) != VectorWidth) {
+      PHINode *PtrPhi = Builder.CreatePHI(Ptr->getType(), 2, "ptr.phi.else");
+      PtrPhi->addIncoming(NewPtr, CondBlock);
+      PtrPhi->addIncoming(Ptr, PrevIfBlock);
+      Ptr = PtrPhi;
+    }
+  }
+
+  CI->replaceAllUsesWith(VResult);
+  CI->eraseFromParent();
+
+  ModifiedDT = true;
+}
+
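+// Translate a masked compress-store intrinsic, like
+//   void @llvm.masked.compressstore.v16i32(<16 x i32> %value, i32* %ptr,
+//                                          <16 x i1> %mask)
+// to a chain of conditional blocks (a sketch; the names follow the other
+// translations above). The enabled lanes of %value are stored to consecutive
+// memory locations: the pointer is advanced by one element only after a lane
+// whose mask bit is set, so disabled lanes leave no gaps in memory.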
+static void scalarizeMaskedCompressStore(CallInst *CI, bool &ModifiedDT) {
+  Value *Src = CI->getArgOperand(0);
+  Value *Ptr = CI->getArgOperand(1);
+  Value *Mask = CI->getArgOperand(2);
+
+  VectorType *VecType = cast<VectorType>(Src->getType());
+
+  IRBuilder<> Builder(CI->getContext());
+  Instruction *InsertPt = CI;
+  BasicBlock *IfBlock = CI->getParent();
+
+  Builder.SetInsertPoint(InsertPt);
+  Builder.SetCurrentDebugLocation(CI->getDebugLoc());
+
+  Type *EltTy = VecType->getVectorElementType();
+
+  unsigned VectorWidth = VecType->getNumElements();
+
+  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
+    // Fill the "else" block, created in the previous iteration
+    //
+    //  %mask_1 = extractelement <16 x i1> %mask, i32 Idx
+    //  br i1 %mask_1, label %cond.store, label %else
+    //
+    Value *Predicate = Builder.CreateExtractElement(Mask, Idx);
+
+    // Create "cond" block
+    //
+    //  %OneElt = extractelement <16 x i32> %Src, i32 Idx
+    //  %EltAddr = getelementptr i32* %1, i32 0
+    //  store i32 %OneElt, i32* %EltAddr
+    //
+    BasicBlock *CondBlock =
+        IfBlock->splitBasicBlock(InsertPt->getIterator(), "cond.store");
+    Builder.SetInsertPoint(InsertPt);
+
+    Value *OneElt = Builder.CreateExtractElement(Src, Idx);
+    Builder.CreateAlignedStore(OneElt, Ptr, 1);
+
+    // Move the pointer if there are more blocks to come.
+    Value *NewPtr;
+    if ((Idx + 1) != VectorWidth)
+      NewPtr = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, 1);
+
+    // Create "else" block, fill it in the next iteration
+    BasicBlock *NewIfBlock =
+        CondBlock->splitBasicBlock(InsertPt->getIterator(), "else");
+    Builder.SetInsertPoint(InsertPt);
+    Instruction *OldBr = IfBlock->getTerminator();
+    BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
+    OldBr->eraseFromParent();
+    BasicBlock *PrevIfBlock = IfBlock;
+    IfBlock = NewIfBlock;
+
+    // Add a PHI for the pointer if this isn't the last iteration.
+    if ((Idx + 1) != VectorWidth) {
+      PHINode *PtrPhi = Builder.CreatePHI(Ptr->getType(), 2, "ptr.phi.else");
+      PtrPhi->addIncoming(NewPtr, CondBlock);
+      PtrPhi->addIncoming(Ptr, PrevIfBlock);
+      Ptr = PtrPhi;
+    }
+  }
+  CI->eraseFromParent();
+
+  ModifiedDT = true;
+}
+
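+// Driver loop: iterate to a fixed point, restarting the scan of the function
+// whenever a scalarization rewrote the CFG (reported through ModifiedDT),
+// since the basic-block splits invalidate the block iterators.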
+bool ScalarizeMaskedMemIntrin::runOnFunction(Function &F) {
+  bool EverMadeChange = false;
+
+  TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
+
+  bool MadeChange = true;
+  while (MadeChange) {
+    MadeChange = false;
+    for (Function::iterator I = F.begin(); I != F.end();) {
+      BasicBlock *BB = &*I++;
+      bool ModifiedDTOnIteration = false;
+      MadeChange |= optimizeBlock(*BB, ModifiedDTOnIteration);
+
+      // Restart BB iteration if the dominator tree of the Function was changed
+      if (ModifiedDTOnIteration)
+        break;
+    }
+
+    EverMadeChange |= MadeChange;
+  }
+
+  return EverMadeChange;
+}
+
+bool ScalarizeMaskedMemIntrin::optimizeBlock(BasicBlock &BB, bool &ModifiedDT) {
+  bool MadeChange = false;
+
+  BasicBlock::iterator CurInstIterator = BB.begin();
+  while (CurInstIterator != BB.end()) {
+    if (CallInst *CI = dyn_cast<CallInst>(&*CurInstIterator++))
+      MadeChange |= optimizeCallInst(CI, ModifiedDT);
+    if (ModifiedDT)
+      return true;
+  }
+
+  return MadeChange;
+}
+
+bool ScalarizeMaskedMemIntrin::optimizeCallInst(CallInst *CI,
+                                                bool &ModifiedDT) {
+  IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI);
+  if (II) {
+    switch (II->getIntrinsicID()) {
+    default:
+      break;
+    case Intrinsic::masked_load:
+      // Scalarize unsupported vector masked load
+      if (TTI->isLegalMaskedLoad(CI->getType()))
+        return false;
+      scalarizeMaskedLoad(CI, ModifiedDT);
+      return true;
+    case Intrinsic::masked_store:
+      if (TTI->isLegalMaskedStore(CI->getArgOperand(0)->getType()))
+        return false;
+      scalarizeMaskedStore(CI, ModifiedDT);
+      return true;
+    case Intrinsic::masked_gather:
+      if (TTI->isLegalMaskedGather(CI->getType()))
+        return false;
+      scalarizeMaskedGather(CI, ModifiedDT);
+      return true;
+    case Intrinsic::masked_scatter:
+      if (TTI->isLegalMaskedScatter(CI->getArgOperand(0)->getType()))
+        return false;
+      scalarizeMaskedScatter(CI, ModifiedDT);
+      return true;
+    case Intrinsic::masked_expandload:
+      if (TTI->isLegalMaskedExpandLoad(CI->getType()))
+        return false;
+      scalarizeMaskedExpandLoad(CI, ModifiedDT);
+      return true;
+    case Intrinsic::masked_compressstore:
+      if (TTI->isLegalMaskedCompressStore(CI->getArgOperand(0)->getType()))
+        return false;
+      scalarizeMaskedCompressStore(CI, ModifiedDT);
+      return true;
+    }
+  }
+
+  return false;
+}