summaryrefslogtreecommitdiff
path: root/contrib/llvm-project/llvm/lib/CodeGen/ScalarizeMaskedMemIntrin.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'contrib/llvm-project/llvm/lib/CodeGen/ScalarizeMaskedMemIntrin.cpp')
-rw-r--r--contrib/llvm-project/llvm/lib/CodeGen/ScalarizeMaskedMemIntrin.cpp765
1 files changed, 765 insertions, 0 deletions
diff --git a/contrib/llvm-project/llvm/lib/CodeGen/ScalarizeMaskedMemIntrin.cpp b/contrib/llvm-project/llvm/lib/CodeGen/ScalarizeMaskedMemIntrin.cpp
new file mode 100644
index 000000000000..7776dffb4e9c
--- /dev/null
+++ b/contrib/llvm-project/llvm/lib/CodeGen/ScalarizeMaskedMemIntrin.cpp
@@ -0,0 +1,765 @@
+//===- ScalarizeMaskedMemIntrin.cpp - Scalarize unsupported masked mem ----===//
+// instrinsics
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This pass replaces masked memory intrinsics - when unsupported by the target
+// - with a chain of basic blocks, that deal with the elements one-by-one if the
+// appropriate mask bit is set.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/ADT/Twine.h"
+#include "llvm/Analysis/TargetTransformInfo.h"
+#include "llvm/CodeGen/TargetSubtargetInfo.h"
+#include "llvm/IR/BasicBlock.h"
+#include "llvm/IR/Constant.h"
+#include "llvm/IR/Constants.h"
+#include "llvm/IR/DerivedTypes.h"
+#include "llvm/IR/Function.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/InstrTypes.h"
+#include "llvm/IR/Instruction.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/IntrinsicInst.h"
+#include "llvm/IR/Intrinsics.h"
+#include "llvm/IR/Type.h"
+#include "llvm/IR/Value.h"
+#include "llvm/Pass.h"
+#include "llvm/Support/Casting.h"
+#include <algorithm>
+#include <cassert>
+
+using namespace llvm;
+
+#define DEBUG_TYPE "scalarize-masked-mem-intrin"
+
+namespace {
+
+class ScalarizeMaskedMemIntrin : public FunctionPass {
+ const TargetTransformInfo *TTI = nullptr;
+
+public:
+ static char ID; // Pass identification, replacement for typeid
+
+ explicit ScalarizeMaskedMemIntrin() : FunctionPass(ID) {
+ initializeScalarizeMaskedMemIntrinPass(*PassRegistry::getPassRegistry());
+ }
+
+ bool runOnFunction(Function &F) override;
+
+ StringRef getPassName() const override {
+ return "Scalarize Masked Memory Intrinsics";
+ }
+
+ void getAnalysisUsage(AnalysisUsage &AU) const override {
+ AU.addRequired<TargetTransformInfoWrapperPass>();
+ }
+
+private:
+ bool optimizeBlock(BasicBlock &BB, bool &ModifiedDT);
+ bool optimizeCallInst(CallInst *CI, bool &ModifiedDT);
+};
+
+} // end anonymous namespace
+
+char ScalarizeMaskedMemIntrin::ID = 0;
+
+INITIALIZE_PASS(ScalarizeMaskedMemIntrin, DEBUG_TYPE,
+ "Scalarize unsupported masked memory intrinsics", false, false)
+
+FunctionPass *llvm::createScalarizeMaskedMemIntrinPass() {
+ return new ScalarizeMaskedMemIntrin();
+}
+
+static bool isConstantIntVector(Value *Mask) {
+ Constant *C = dyn_cast<Constant>(Mask);
+ if (!C)
+ return false;
+
+ unsigned NumElts = Mask->getType()->getVectorNumElements();
+ for (unsigned i = 0; i != NumElts; ++i) {
+ Constant *CElt = C->getAggregateElement(i);
+ if (!CElt || !isa<ConstantInt>(CElt))
+ return false;
+ }
+
+ return true;
+}
+
+// Translate a masked load intrinsic like
+// <16 x i32 > @llvm.masked.load( <16 x i32>* %addr, i32 align,
+// <16 x i1> %mask, <16 x i32> %passthru)
+// to a chain of basic blocks, with loading element one-by-one if
+// the appropriate mask bit is set
+//
+// %1 = bitcast i8* %addr to i32*
+// %2 = extractelement <16 x i1> %mask, i32 0
+// br i1 %2, label %cond.load, label %else
+//
+// cond.load: ; preds = %0
+// %3 = getelementptr i32* %1, i32 0
+// %4 = load i32* %3
+// %5 = insertelement <16 x i32> %passthru, i32 %4, i32 0
+// br label %else
+//
+// else: ; preds = %0, %cond.load
+// %res.phi.else = phi <16 x i32> [ %5, %cond.load ], [ undef, %0 ]
+// %6 = extractelement <16 x i1> %mask, i32 1
+// br i1 %6, label %cond.load1, label %else2
+//
+// cond.load1: ; preds = %else
+// %7 = getelementptr i32* %1, i32 1
+// %8 = load i32* %7
+// %9 = insertelement <16 x i32> %res.phi.else, i32 %8, i32 1
+// br label %else2
+//
+// else2: ; preds = %else, %cond.load1
+// %res.phi.else3 = phi <16 x i32> [ %9, %cond.load1 ], [ %res.phi.else, %else ]
+// %10 = extractelement <16 x i1> %mask, i32 2
+// br i1 %10, label %cond.load4, label %else5
+//
+static void scalarizeMaskedLoad(CallInst *CI, bool &ModifiedDT) {
+ Value *Ptr = CI->getArgOperand(0);
+ Value *Alignment = CI->getArgOperand(1);
+ Value *Mask = CI->getArgOperand(2);
+ Value *Src0 = CI->getArgOperand(3);
+
+ unsigned AlignVal = cast<ConstantInt>(Alignment)->getZExtValue();
+ VectorType *VecType = cast<VectorType>(CI->getType());
+
+ Type *EltTy = VecType->getElementType();
+
+ IRBuilder<> Builder(CI->getContext());
+ Instruction *InsertPt = CI;
+ BasicBlock *IfBlock = CI->getParent();
+
+ Builder.SetInsertPoint(InsertPt);
+ Builder.SetCurrentDebugLocation(CI->getDebugLoc());
+
+ // Short-cut if the mask is all-true.
+ if (isa<Constant>(Mask) && cast<Constant>(Mask)->isAllOnesValue()) {
+ Value *NewI = Builder.CreateAlignedLoad(VecType, Ptr, AlignVal);
+ CI->replaceAllUsesWith(NewI);
+ CI->eraseFromParent();
+ return;
+ }
+
+ // Adjust alignment for the scalar instruction.
+ AlignVal = MinAlign(AlignVal, EltTy->getPrimitiveSizeInBits() / 8);
+ // Bitcast %addr from i8* to EltTy*
+ Type *NewPtrType =
+ EltTy->getPointerTo(Ptr->getType()->getPointerAddressSpace());
+ Value *FirstEltPtr = Builder.CreateBitCast(Ptr, NewPtrType);
+ unsigned VectorWidth = VecType->getNumElements();
+
+ // The result vector
+ Value *VResult = Src0;
+
+ if (isConstantIntVector(Mask)) {
+ for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
+ if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
+ continue;
+ Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, FirstEltPtr, Idx);
+ LoadInst *Load = Builder.CreateAlignedLoad(EltTy, Gep, AlignVal);
+ VResult = Builder.CreateInsertElement(VResult, Load, Idx);
+ }
+ CI->replaceAllUsesWith(VResult);
+ CI->eraseFromParent();
+ return;
+ }
+
+ for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
+ // Fill the "else" block, created in the previous iteration
+ //
+ // %res.phi.else3 = phi <16 x i32> [ %11, %cond.load1 ], [ %res.phi.else, %else ]
+ // %mask_1 = extractelement <16 x i1> %mask, i32 Idx
+ // br i1 %mask_1, label %cond.load, label %else
+ //
+
+ Value *Predicate = Builder.CreateExtractElement(Mask, Idx);
+
+ // Create "cond" block
+ //
+ // %EltAddr = getelementptr i32* %1, i32 0
+ // %Elt = load i32* %EltAddr
+ // VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
+ //
+ BasicBlock *CondBlock = IfBlock->splitBasicBlock(InsertPt->getIterator(),
+ "cond.load");
+ Builder.SetInsertPoint(InsertPt);
+
+ Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, FirstEltPtr, Idx);
+ LoadInst *Load = Builder.CreateAlignedLoad(EltTy, Gep, AlignVal);
+ Value *NewVResult = Builder.CreateInsertElement(VResult, Load, Idx);
+
+ // Create "else" block, fill it in the next iteration
+ BasicBlock *NewIfBlock =
+ CondBlock->splitBasicBlock(InsertPt->getIterator(), "else");
+ Builder.SetInsertPoint(InsertPt);
+ Instruction *OldBr = IfBlock->getTerminator();
+ BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
+ OldBr->eraseFromParent();
+ BasicBlock *PrevIfBlock = IfBlock;
+ IfBlock = NewIfBlock;
+
+ // Create the phi to join the new and previous value.
+ PHINode *Phi = Builder.CreatePHI(VecType, 2, "res.phi.else");
+ Phi->addIncoming(NewVResult, CondBlock);
+ Phi->addIncoming(VResult, PrevIfBlock);
+ VResult = Phi;
+ }
+
+ CI->replaceAllUsesWith(VResult);
+ CI->eraseFromParent();
+
+ ModifiedDT = true;
+}
+
+// Translate a masked store intrinsic, like
+// void @llvm.masked.store(<16 x i32> %src, <16 x i32>* %addr, i32 align,
+// <16 x i1> %mask)
+// to a chain of basic blocks, that stores element one-by-one if
+// the appropriate mask bit is set
+//
+// %1 = bitcast i8* %addr to i32*
+// %2 = extractelement <16 x i1> %mask, i32 0
+// br i1 %2, label %cond.store, label %else
+//
+// cond.store: ; preds = %0
+// %3 = extractelement <16 x i32> %val, i32 0
+// %4 = getelementptr i32* %1, i32 0
+// store i32 %3, i32* %4
+// br label %else
+//
+// else: ; preds = %0, %cond.store
+// %5 = extractelement <16 x i1> %mask, i32 1
+// br i1 %5, label %cond.store1, label %else2
+//
+// cond.store1: ; preds = %else
+// %6 = extractelement <16 x i32> %val, i32 1
+// %7 = getelementptr i32* %1, i32 1
+// store i32 %6, i32* %7
+// br label %else2
+// . . .
+static void scalarizeMaskedStore(CallInst *CI, bool &ModifiedDT) {
+ Value *Src = CI->getArgOperand(0);
+ Value *Ptr = CI->getArgOperand(1);
+ Value *Alignment = CI->getArgOperand(2);
+ Value *Mask = CI->getArgOperand(3);
+
+ unsigned AlignVal = cast<ConstantInt>(Alignment)->getZExtValue();
+ VectorType *VecType = cast<VectorType>(Src->getType());
+
+ Type *EltTy = VecType->getElementType();
+
+ IRBuilder<> Builder(CI->getContext());
+ Instruction *InsertPt = CI;
+ BasicBlock *IfBlock = CI->getParent();
+ Builder.SetInsertPoint(InsertPt);
+ Builder.SetCurrentDebugLocation(CI->getDebugLoc());
+
+ // Short-cut if the mask is all-true.
+ if (isa<Constant>(Mask) && cast<Constant>(Mask)->isAllOnesValue()) {
+ Builder.CreateAlignedStore(Src, Ptr, AlignVal);
+ CI->eraseFromParent();
+ return;
+ }
+
+ // Adjust alignment for the scalar instruction.
+ AlignVal = MinAlign(AlignVal, EltTy->getPrimitiveSizeInBits() / 8);
+ // Bitcast %addr from i8* to EltTy*
+ Type *NewPtrType =
+ EltTy->getPointerTo(Ptr->getType()->getPointerAddressSpace());
+ Value *FirstEltPtr = Builder.CreateBitCast(Ptr, NewPtrType);
+ unsigned VectorWidth = VecType->getNumElements();
+
+ if (isConstantIntVector(Mask)) {
+ for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
+ if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
+ continue;
+ Value *OneElt = Builder.CreateExtractElement(Src, Idx);
+ Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, FirstEltPtr, Idx);
+ Builder.CreateAlignedStore(OneElt, Gep, AlignVal);
+ }
+ CI->eraseFromParent();
+ return;
+ }
+
+ for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
+ // Fill the "else" block, created in the previous iteration
+ //
+ // %mask_1 = extractelement <16 x i1> %mask, i32 Idx
+ // br i1 %mask_1, label %cond.store, label %else
+ //
+ Value *Predicate = Builder.CreateExtractElement(Mask, Idx);
+
+ // Create "cond" block
+ //
+ // %OneElt = extractelement <16 x i32> %Src, i32 Idx
+ // %EltAddr = getelementptr i32* %1, i32 0
+ // %store i32 %OneElt, i32* %EltAddr
+ //
+ BasicBlock *CondBlock =
+ IfBlock->splitBasicBlock(InsertPt->getIterator(), "cond.store");
+ Builder.SetInsertPoint(InsertPt);
+
+ Value *OneElt = Builder.CreateExtractElement(Src, Idx);
+ Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, FirstEltPtr, Idx);
+ Builder.CreateAlignedStore(OneElt, Gep, AlignVal);
+
+ // Create "else" block, fill it in the next iteration
+ BasicBlock *NewIfBlock =
+ CondBlock->splitBasicBlock(InsertPt->getIterator(), "else");
+ Builder.SetInsertPoint(InsertPt);
+ Instruction *OldBr = IfBlock->getTerminator();
+ BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
+ OldBr->eraseFromParent();
+ IfBlock = NewIfBlock;
+ }
+ CI->eraseFromParent();
+
+ ModifiedDT = true;
+}
+
+// Translate a masked gather intrinsic like
+// <16 x i32 > @llvm.masked.gather.v16i32( <16 x i32*> %Ptrs, i32 4,
+// <16 x i1> %Mask, <16 x i32> %Src)
+// to a chain of basic blocks, with loading element one-by-one if
+// the appropriate mask bit is set
+//
+// %Ptrs = getelementptr i32, i32* %base, <16 x i64> %ind
+// %Mask0 = extractelement <16 x i1> %Mask, i32 0
+// br i1 %Mask0, label %cond.load, label %else
+//
+// cond.load:
+// %Ptr0 = extractelement <16 x i32*> %Ptrs, i32 0
+// %Load0 = load i32, i32* %Ptr0, align 4
+// %Res0 = insertelement <16 x i32> undef, i32 %Load0, i32 0
+// br label %else
+//
+// else:
+// %res.phi.else = phi <16 x i32>[%Res0, %cond.load], [undef, %0]
+// %Mask1 = extractelement <16 x i1> %Mask, i32 1
+// br i1 %Mask1, label %cond.load1, label %else2
+//
+// cond.load1:
+// %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
+// %Load1 = load i32, i32* %Ptr1, align 4
+// %Res1 = insertelement <16 x i32> %res.phi.else, i32 %Load1, i32 1
+// br label %else2
+// . . .
+// %Result = select <16 x i1> %Mask, <16 x i32> %res.phi.select, <16 x i32> %Src
+// ret <16 x i32> %Result
+static void scalarizeMaskedGather(CallInst *CI, bool &ModifiedDT) {
+ Value *Ptrs = CI->getArgOperand(0);
+ Value *Alignment = CI->getArgOperand(1);
+ Value *Mask = CI->getArgOperand(2);
+ Value *Src0 = CI->getArgOperand(3);
+
+ VectorType *VecType = cast<VectorType>(CI->getType());
+ Type *EltTy = VecType->getElementType();
+
+ IRBuilder<> Builder(CI->getContext());
+ Instruction *InsertPt = CI;
+ BasicBlock *IfBlock = CI->getParent();
+ Builder.SetInsertPoint(InsertPt);
+ unsigned AlignVal = cast<ConstantInt>(Alignment)->getZExtValue();
+
+ Builder.SetCurrentDebugLocation(CI->getDebugLoc());
+
+ // The result vector
+ Value *VResult = Src0;
+ unsigned VectorWidth = VecType->getNumElements();
+
+ // Shorten the way if the mask is a vector of constants.
+ if (isConstantIntVector(Mask)) {
+ for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
+ if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
+ continue;
+ Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
+ LoadInst *Load =
+ Builder.CreateAlignedLoad(EltTy, Ptr, AlignVal, "Load" + Twine(Idx));
+ VResult =
+ Builder.CreateInsertElement(VResult, Load, Idx, "Res" + Twine(Idx));
+ }
+ CI->replaceAllUsesWith(VResult);
+ CI->eraseFromParent();
+ return;
+ }
+
+ for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
+ // Fill the "else" block, created in the previous iteration
+ //
+ // %Mask1 = extractelement <16 x i1> %Mask, i32 1
+ // br i1 %Mask1, label %cond.load, label %else
+ //
+
+ Value *Predicate =
+ Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));
+
+ // Create "cond" block
+ //
+ // %EltAddr = getelementptr i32* %1, i32 0
+ // %Elt = load i32* %EltAddr
+ // VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
+ //
+ BasicBlock *CondBlock = IfBlock->splitBasicBlock(InsertPt, "cond.load");
+ Builder.SetInsertPoint(InsertPt);
+
+ Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
+ LoadInst *Load =
+ Builder.CreateAlignedLoad(EltTy, Ptr, AlignVal, "Load" + Twine(Idx));
+ Value *NewVResult =
+ Builder.CreateInsertElement(VResult, Load, Idx, "Res" + Twine(Idx));
+
+ // Create "else" block, fill it in the next iteration
+ BasicBlock *NewIfBlock = CondBlock->splitBasicBlock(InsertPt, "else");
+ Builder.SetInsertPoint(InsertPt);
+ Instruction *OldBr = IfBlock->getTerminator();
+ BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
+ OldBr->eraseFromParent();
+ BasicBlock *PrevIfBlock = IfBlock;
+ IfBlock = NewIfBlock;
+
+ PHINode *Phi = Builder.CreatePHI(VecType, 2, "res.phi.else");
+ Phi->addIncoming(NewVResult, CondBlock);
+ Phi->addIncoming(VResult, PrevIfBlock);
+ VResult = Phi;
+ }
+
+ CI->replaceAllUsesWith(VResult);
+ CI->eraseFromParent();
+
+ ModifiedDT = true;
+}
+
+// Translate a masked scatter intrinsic, like
+// void @llvm.masked.scatter.v16i32(<16 x i32> %Src, <16 x i32*>* %Ptrs, i32 4,
+// <16 x i1> %Mask)
+// to a chain of basic blocks, that stores element one-by-one if
+// the appropriate mask bit is set.
+//
+// %Ptrs = getelementptr i32, i32* %ptr, <16 x i64> %ind
+// %Mask0 = extractelement <16 x i1> %Mask, i32 0
+// br i1 %Mask0, label %cond.store, label %else
+//
+// cond.store:
+// %Elt0 = extractelement <16 x i32> %Src, i32 0
+// %Ptr0 = extractelement <16 x i32*> %Ptrs, i32 0
+// store i32 %Elt0, i32* %Ptr0, align 4
+// br label %else
+//
+// else:
+// %Mask1 = extractelement <16 x i1> %Mask, i32 1
+// br i1 %Mask1, label %cond.store1, label %else2
+//
+// cond.store1:
+// %Elt1 = extractelement <16 x i32> %Src, i32 1
+// %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
+// store i32 %Elt1, i32* %Ptr1, align 4
+// br label %else2
+// . . .
+static void scalarizeMaskedScatter(CallInst *CI, bool &ModifiedDT) {
+ Value *Src = CI->getArgOperand(0);
+ Value *Ptrs = CI->getArgOperand(1);
+ Value *Alignment = CI->getArgOperand(2);
+ Value *Mask = CI->getArgOperand(3);
+
+ assert(isa<VectorType>(Src->getType()) &&
+ "Unexpected data type in masked scatter intrinsic");
+ assert(isa<VectorType>(Ptrs->getType()) &&
+ isa<PointerType>(Ptrs->getType()->getVectorElementType()) &&
+ "Vector of pointers is expected in masked scatter intrinsic");
+
+ IRBuilder<> Builder(CI->getContext());
+ Instruction *InsertPt = CI;
+ BasicBlock *IfBlock = CI->getParent();
+ Builder.SetInsertPoint(InsertPt);
+ Builder.SetCurrentDebugLocation(CI->getDebugLoc());
+
+ unsigned AlignVal = cast<ConstantInt>(Alignment)->getZExtValue();
+ unsigned VectorWidth = Src->getType()->getVectorNumElements();
+
+ // Shorten the way if the mask is a vector of constants.
+ if (isConstantIntVector(Mask)) {
+ for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
+ if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
+ continue;
+ Value *OneElt =
+ Builder.CreateExtractElement(Src, Idx, "Elt" + Twine(Idx));
+ Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
+ Builder.CreateAlignedStore(OneElt, Ptr, AlignVal);
+ }
+ CI->eraseFromParent();
+ return;
+ }
+
+ for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
+ // Fill the "else" block, created in the previous iteration
+ //
+ // %Mask1 = extractelement <16 x i1> %Mask, i32 Idx
+ // br i1 %Mask1, label %cond.store, label %else
+ //
+ Value *Predicate =
+ Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));
+
+ // Create "cond" block
+ //
+ // %Elt1 = extractelement <16 x i32> %Src, i32 1
+ // %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
+ // %store i32 %Elt1, i32* %Ptr1
+ //
+ BasicBlock *CondBlock = IfBlock->splitBasicBlock(InsertPt, "cond.store");
+ Builder.SetInsertPoint(InsertPt);
+
+ Value *OneElt = Builder.CreateExtractElement(Src, Idx, "Elt" + Twine(Idx));
+ Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
+ Builder.CreateAlignedStore(OneElt, Ptr, AlignVal);
+
+ // Create "else" block, fill it in the next iteration
+ BasicBlock *NewIfBlock = CondBlock->splitBasicBlock(InsertPt, "else");
+ Builder.SetInsertPoint(InsertPt);
+ Instruction *OldBr = IfBlock->getTerminator();
+ BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
+ OldBr->eraseFromParent();
+ IfBlock = NewIfBlock;
+ }
+ CI->eraseFromParent();
+
+ ModifiedDT = true;
+}
+
+static void scalarizeMaskedExpandLoad(CallInst *CI, bool &ModifiedDT) {
+ Value *Ptr = CI->getArgOperand(0);
+ Value *Mask = CI->getArgOperand(1);
+ Value *PassThru = CI->getArgOperand(2);
+
+ VectorType *VecType = cast<VectorType>(CI->getType());
+
+ Type *EltTy = VecType->getElementType();
+
+ IRBuilder<> Builder(CI->getContext());
+ Instruction *InsertPt = CI;
+ BasicBlock *IfBlock = CI->getParent();
+
+ Builder.SetInsertPoint(InsertPt);
+ Builder.SetCurrentDebugLocation(CI->getDebugLoc());
+
+ unsigned VectorWidth = VecType->getNumElements();
+
+ // The result vector
+ Value *VResult = PassThru;
+
+ for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
+ // Fill the "else" block, created in the previous iteration
+ //
+ // %res.phi.else3 = phi <16 x i32> [ %11, %cond.load1 ], [ %res.phi.else, %else ]
+ // %mask_1 = extractelement <16 x i1> %mask, i32 Idx
+ // br i1 %mask_1, label %cond.load, label %else
+ //
+
+ Value *Predicate =
+ Builder.CreateExtractElement(Mask, Idx);
+
+ // Create "cond" block
+ //
+ // %EltAddr = getelementptr i32* %1, i32 0
+ // %Elt = load i32* %EltAddr
+ // VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
+ //
+ BasicBlock *CondBlock = IfBlock->splitBasicBlock(InsertPt->getIterator(),
+ "cond.load");
+ Builder.SetInsertPoint(InsertPt);
+
+ LoadInst *Load = Builder.CreateAlignedLoad(EltTy, Ptr, 1);
+ Value *NewVResult = Builder.CreateInsertElement(VResult, Load, Idx);
+
+ // Move the pointer if there are more blocks to come.
+ Value *NewPtr;
+ if ((Idx + 1) != VectorWidth)
+ NewPtr = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, 1);
+
+ // Create "else" block, fill it in the next iteration
+ BasicBlock *NewIfBlock =
+ CondBlock->splitBasicBlock(InsertPt->getIterator(), "else");
+ Builder.SetInsertPoint(InsertPt);
+ Instruction *OldBr = IfBlock->getTerminator();
+ BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
+ OldBr->eraseFromParent();
+ BasicBlock *PrevIfBlock = IfBlock;
+ IfBlock = NewIfBlock;
+
+ // Create the phi to join the new and previous value.
+ PHINode *ResultPhi = Builder.CreatePHI(VecType, 2, "res.phi.else");
+ ResultPhi->addIncoming(NewVResult, CondBlock);
+ ResultPhi->addIncoming(VResult, PrevIfBlock);
+ VResult = ResultPhi;
+
+ // Add a PHI for the pointer if this isn't the last iteration.
+ if ((Idx + 1) != VectorWidth) {
+ PHINode *PtrPhi = Builder.CreatePHI(Ptr->getType(), 2, "ptr.phi.else");
+ PtrPhi->addIncoming(NewPtr, CondBlock);
+ PtrPhi->addIncoming(Ptr, PrevIfBlock);
+ Ptr = PtrPhi;
+ }
+ }
+
+ CI->replaceAllUsesWith(VResult);
+ CI->eraseFromParent();
+
+ ModifiedDT = true;
+}
+
+static void scalarizeMaskedCompressStore(CallInst *CI, bool &ModifiedDT) {
+ Value *Src = CI->getArgOperand(0);
+ Value *Ptr = CI->getArgOperand(1);
+ Value *Mask = CI->getArgOperand(2);
+
+ VectorType *VecType = cast<VectorType>(Src->getType());
+
+ IRBuilder<> Builder(CI->getContext());
+ Instruction *InsertPt = CI;
+ BasicBlock *IfBlock = CI->getParent();
+
+ Builder.SetInsertPoint(InsertPt);
+ Builder.SetCurrentDebugLocation(CI->getDebugLoc());
+
+ Type *EltTy = VecType->getVectorElementType();
+
+ unsigned VectorWidth = VecType->getNumElements();
+
+ for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
+ // Fill the "else" block, created in the previous iteration
+ //
+ // %mask_1 = extractelement <16 x i1> %mask, i32 Idx
+ // br i1 %mask_1, label %cond.store, label %else
+ //
+ Value *Predicate = Builder.CreateExtractElement(Mask, Idx);
+
+ // Create "cond" block
+ //
+ // %OneElt = extractelement <16 x i32> %Src, i32 Idx
+ // %EltAddr = getelementptr i32* %1, i32 0
+ // %store i32 %OneElt, i32* %EltAddr
+ //
+ BasicBlock *CondBlock =
+ IfBlock->splitBasicBlock(InsertPt->getIterator(), "cond.store");
+ Builder.SetInsertPoint(InsertPt);
+
+ Value *OneElt = Builder.CreateExtractElement(Src, Idx);
+ Builder.CreateAlignedStore(OneElt, Ptr, 1);
+
+ // Move the pointer if there are more blocks to come.
+ Value *NewPtr;
+ if ((Idx + 1) != VectorWidth)
+ NewPtr = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, 1);
+
+ // Create "else" block, fill it in the next iteration
+ BasicBlock *NewIfBlock =
+ CondBlock->splitBasicBlock(InsertPt->getIterator(), "else");
+ Builder.SetInsertPoint(InsertPt);
+ Instruction *OldBr = IfBlock->getTerminator();
+ BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
+ OldBr->eraseFromParent();
+ BasicBlock *PrevIfBlock = IfBlock;
+ IfBlock = NewIfBlock;
+
+ // Add a PHI for the pointer if this isn't the last iteration.
+ if ((Idx + 1) != VectorWidth) {
+ PHINode *PtrPhi = Builder.CreatePHI(Ptr->getType(), 2, "ptr.phi.else");
+ PtrPhi->addIncoming(NewPtr, CondBlock);
+ PtrPhi->addIncoming(Ptr, PrevIfBlock);
+ Ptr = PtrPhi;
+ }
+ }
+ CI->eraseFromParent();
+
+ ModifiedDT = true;
+}
+
+bool ScalarizeMaskedMemIntrin::runOnFunction(Function &F) {
+ bool EverMadeChange = false;
+
+ TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
+
+ bool MadeChange = true;
+ while (MadeChange) {
+ MadeChange = false;
+ for (Function::iterator I = F.begin(); I != F.end();) {
+ BasicBlock *BB = &*I++;
+ bool ModifiedDTOnIteration = false;
+ MadeChange |= optimizeBlock(*BB, ModifiedDTOnIteration);
+
+ // Restart BB iteration if the dominator tree of the Function was changed
+ if (ModifiedDTOnIteration)
+ break;
+ }
+
+ EverMadeChange |= MadeChange;
+ }
+
+ return EverMadeChange;
+}
+
+bool ScalarizeMaskedMemIntrin::optimizeBlock(BasicBlock &BB, bool &ModifiedDT) {
+ bool MadeChange = false;
+
+ BasicBlock::iterator CurInstIterator = BB.begin();
+ while (CurInstIterator != BB.end()) {
+ if (CallInst *CI = dyn_cast<CallInst>(&*CurInstIterator++))
+ MadeChange |= optimizeCallInst(CI, ModifiedDT);
+ if (ModifiedDT)
+ return true;
+ }
+
+ return MadeChange;
+}
+
+bool ScalarizeMaskedMemIntrin::optimizeCallInst(CallInst *CI,
+ bool &ModifiedDT) {
+ IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI);
+ if (II) {
+ switch (II->getIntrinsicID()) {
+ default:
+ break;
+ case Intrinsic::masked_load:
+ // Scalarize unsupported vector masked load
+ if (TTI->isLegalMaskedLoad(CI->getType()))
+ return false;
+ scalarizeMaskedLoad(CI, ModifiedDT);
+ return true;
+ case Intrinsic::masked_store:
+ if (TTI->isLegalMaskedStore(CI->getArgOperand(0)->getType()))
+ return false;
+ scalarizeMaskedStore(CI, ModifiedDT);
+ return true;
+ case Intrinsic::masked_gather:
+ if (TTI->isLegalMaskedGather(CI->getType()))
+ return false;
+ scalarizeMaskedGather(CI, ModifiedDT);
+ return true;
+ case Intrinsic::masked_scatter:
+ if (TTI->isLegalMaskedScatter(CI->getArgOperand(0)->getType()))
+ return false;
+ scalarizeMaskedScatter(CI, ModifiedDT);
+ return true;
+ case Intrinsic::masked_expandload:
+ if (TTI->isLegalMaskedExpandLoad(CI->getType()))
+ return false;
+ scalarizeMaskedExpandLoad(CI, ModifiedDT);
+ return true;
+ case Intrinsic::masked_compressstore:
+ if (TTI->isLegalMaskedCompressStore(CI->getArgOperand(0)->getType()))
+ return false;
+ scalarizeMaskedCompressStore(CI, ModifiedDT);
+ return true;
+ }
+ }
+
+ return false;
+}