Diffstat (limited to 'llvm/lib/CodeGen/ScalarizeMaskedMemIntrin.cpp')
-rw-r--r--	llvm/lib/CodeGen/ScalarizeMaskedMemIntrin.cpp	896
1 file changed, 896 insertions, 0 deletions
diff --git a/llvm/lib/CodeGen/ScalarizeMaskedMemIntrin.cpp b/llvm/lib/CodeGen/ScalarizeMaskedMemIntrin.cpp
new file mode 100644
index 000000000000..b4037499d7d1
--- /dev/null
+++ b/llvm/lib/CodeGen/ScalarizeMaskedMemIntrin.cpp
@@ -0,0 +1,896 @@
+//===- ScalarizeMaskedMemIntrin.cpp - Scalarize unsupported masked mem ----===//
+//                                    intrinsics
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This pass replaces masked memory intrinsics - when unsupported by the target
+// - with a chain of basic blocks that deal with the elements one-by-one if the
+// appropriate mask bit is set.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/ADT/Twine.h"
+#include "llvm/Analysis/TargetTransformInfo.h"
+#include "llvm/CodeGen/TargetSubtargetInfo.h"
+#include "llvm/IR/BasicBlock.h"
+#include "llvm/IR/Constant.h"
+#include "llvm/IR/Constants.h"
+#include "llvm/IR/DerivedTypes.h"
+#include "llvm/IR/Function.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/InstrTypes.h"
+#include "llvm/IR/Instruction.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/IntrinsicInst.h"
+#include "llvm/IR/Intrinsics.h"
+#include "llvm/IR/Type.h"
+#include "llvm/IR/Value.h"
+#include "llvm/Pass.h"
+#include "llvm/Support/Casting.h"
+#include <algorithm>
+#include <cassert>
+
+using namespace llvm;
+
+#define DEBUG_TYPE "scalarize-masked-mem-intrin"
+
+namespace {
+
+class ScalarizeMaskedMemIntrin : public FunctionPass {
+  const TargetTransformInfo *TTI = nullptr;
+
+public:
+  static char ID; // Pass identification, replacement for typeid
+
+  explicit ScalarizeMaskedMemIntrin() : FunctionPass(ID) {
+    initializeScalarizeMaskedMemIntrinPass(*PassRegistry::getPassRegistry());
+  }
+
+  bool runOnFunction(Function &F) override;
+
+  StringRef getPassName() const override {
+    return "Scalarize Masked Memory Intrinsics";
+  }
+
+  void getAnalysisUsage(AnalysisUsage &AU) const override {
+    AU.addRequired<TargetTransformInfoWrapperPass>();
+  }
+
+private:
+  bool optimizeBlock(BasicBlock &BB, bool &ModifiedDT);
+  bool optimizeCallInst(CallInst *CI, bool &ModifiedDT);
+};
+
+} // end anonymous namespace
+
+char ScalarizeMaskedMemIntrin::ID = 0;
+
+INITIALIZE_PASS(ScalarizeMaskedMemIntrin, DEBUG_TYPE,
+                "Scalarize unsupported masked memory intrinsics", false, false)
+
+FunctionPass *llvm::createScalarizeMaskedMemIntrinPass() {
+  return new ScalarizeMaskedMemIntrin();
+}
+
+static bool isConstantIntVector(Value *Mask) {
+  Constant *C = dyn_cast<Constant>(Mask);
+  if (!C)
+    return false;
+
+  unsigned NumElts = Mask->getType()->getVectorNumElements();
+  for (unsigned i = 0; i != NumElts; ++i) {
+    Constant *CElt = C->getAggregateElement(i);
+    if (!CElt || !isa<ConstantInt>(CElt))
+      return false;
+  }
+
+  return true;
+}
+
+// Translate a masked load intrinsic like
+// <16 x i32> @llvm.masked.load(<16 x i32>* %addr, i32 align,
+//                              <16 x i1> %mask, <16 x i32> %passthru)
+// to a chain of basic blocks, loading the elements one-by-one if
+// the appropriate mask bit is set
+//
+//  %1 = bitcast i8* %addr to i32*
+//  %2 = extractelement <16 x i1> %mask, i32 0
+//  br i1 %2, label %cond.load, label %else
+//
+// cond.load:                                        ; preds = %0
+//  %3 = getelementptr i32* %1, i32 0
+//  %4 = load i32* %3
+//  %5 = insertelement <16 x i32> %passthru, i32 %4, i32 0
+//  br label %else
+//
+// else:                                             ; preds = %0, %cond.load
+//  %res.phi.else = phi <16 x i32> [ %5, %cond.load ], [ undef, %0 ]
+//  %6 = extractelement <16 x i1> %mask, i32 1
+//  br i1 %6, label %cond.load1, label %else2
+//
+// cond.load1:                                       ; preds = %else
+//  %7 = getelementptr i32* %1, i32 1
+//  %8 = load i32* %7
+//  %9 = insertelement <16 x i32> %res.phi.else, i32 %8, i32 1
+//  br label %else2
+//
+// else2:                                          ; preds = %else, %cond.load1
+//  %res.phi.else3 = phi <16 x i32> [ %9, %cond.load1 ], [ %res.phi.else, %else ]
+//  %10 = extractelement <16 x i1> %mask, i32 2
+//  br i1 %10, label %cond.load4, label %else5
+//
+static void scalarizeMaskedLoad(CallInst *CI, bool &ModifiedDT) {
+  Value *Ptr = CI->getArgOperand(0);
+  Value *Alignment = CI->getArgOperand(1);
+  Value *Mask = CI->getArgOperand(2);
+  Value *Src0 = CI->getArgOperand(3);
+
+  unsigned AlignVal = cast<ConstantInt>(Alignment)->getZExtValue();
+  VectorType *VecType = cast<VectorType>(CI->getType());
+
+  Type *EltTy = VecType->getElementType();
+
+  IRBuilder<> Builder(CI->getContext());
+  Instruction *InsertPt = CI;
+  BasicBlock *IfBlock = CI->getParent();
+
+  Builder.SetInsertPoint(InsertPt);
+  Builder.SetCurrentDebugLocation(CI->getDebugLoc());
+
+  // Short-cut if the mask is all-true.
+  if (isa<Constant>(Mask) && cast<Constant>(Mask)->isAllOnesValue()) {
+    Value *NewI = Builder.CreateAlignedLoad(VecType, Ptr, AlignVal);
+    CI->replaceAllUsesWith(NewI);
+    CI->eraseFromParent();
+    return;
+  }
+
+  // Adjust alignment for the scalar instruction.
+  AlignVal = MinAlign(AlignVal, EltTy->getPrimitiveSizeInBits() / 8);
+  // Bitcast %addr from i8* to EltTy*
+  Type *NewPtrType =
+      EltTy->getPointerTo(Ptr->getType()->getPointerAddressSpace());
+  Value *FirstEltPtr = Builder.CreateBitCast(Ptr, NewPtrType);
+  unsigned VectorWidth = VecType->getNumElements();
+
+  // The result vector
+  Value *VResult = Src0;
+
+  if (isConstantIntVector(Mask)) {
+    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
+      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
+        continue;
+      Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, FirstEltPtr, Idx);
+      LoadInst *Load = Builder.CreateAlignedLoad(EltTy, Gep, AlignVal);
+      VResult = Builder.CreateInsertElement(VResult, Load, Idx);
+    }
+    CI->replaceAllUsesWith(VResult);
+    CI->eraseFromParent();
+    return;
+  }
+
+  // If the mask is not v1i1, use scalar bit test operations. This generates
+  // better results on X86 at least.
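+  // For a <16 x i1> mask this emits, per lane Idx (illustrative IR; the
+  // value names are invented for readability):
+  //   %scalar_mask = bitcast <16 x i1> %mask to i16
+  //   %bit  = and i16 %scalar_mask, 1 << Idx
+  //   %pred = icmp ne i16 %bit, 0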
+  Value *SclrMask;
+  if (VectorWidth != 1) {
+    Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
+    SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
+  }
+
+  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
+    // Fill the "else" block, created in the previous iteration
+    //
+    //  %res.phi.else3 = phi <16 x i32> [ %11, %cond.load1 ], [ %res.phi.else, %else ]
+    //  %mask_1 = and i16 %scalar_mask, i32 1 << Idx
+    //  %cond = icmp ne i16 %mask_1, 0
+    //  br i1 %mask_1, label %cond.load, label %else
+    //
+    Value *Predicate;
+    if (VectorWidth != 1) {
+      Value *Mask = Builder.getInt(APInt::getOneBitSet(VectorWidth, Idx));
+      Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
+                                       Builder.getIntN(VectorWidth, 0));
+    } else {
+      Predicate = Builder.CreateExtractElement(Mask, Idx);
+    }
+
+    // Create "cond" block
+    //
+    //  %EltAddr = getelementptr i32* %1, i32 0
+    //  %Elt = load i32* %EltAddr
+    //  VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
+    //
+    BasicBlock *CondBlock = IfBlock->splitBasicBlock(InsertPt->getIterator(),
+                                                     "cond.load");
+    Builder.SetInsertPoint(InsertPt);
+
+    Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, FirstEltPtr, Idx);
+    LoadInst *Load = Builder.CreateAlignedLoad(EltTy, Gep, AlignVal);
+    Value *NewVResult = Builder.CreateInsertElement(VResult, Load, Idx);
+
+    // Create "else" block, fill it in the next iteration
+    BasicBlock *NewIfBlock =
+        CondBlock->splitBasicBlock(InsertPt->getIterator(), "else");
+    Builder.SetInsertPoint(InsertPt);
+    Instruction *OldBr = IfBlock->getTerminator();
+    BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
+    OldBr->eraseFromParent();
+    BasicBlock *PrevIfBlock = IfBlock;
+    IfBlock = NewIfBlock;
+
+    // Create the phi to join the new and previous value.
+    PHINode *Phi = Builder.CreatePHI(VecType, 2, "res.phi.else");
+    Phi->addIncoming(NewVResult, CondBlock);
+    Phi->addIncoming(VResult, PrevIfBlock);
+    VResult = Phi;
+  }
+
+  CI->replaceAllUsesWith(VResult);
+  CI->eraseFromParent();
+
+  ModifiedDT = true;
+}
+
+// Translate a masked store intrinsic, like
+// void @llvm.masked.store(<16 x i32> %src, <16 x i32>* %addr, i32 align,
+//                         <16 x i1> %mask)
+// to a chain of basic blocks, storing the elements one-by-one if
+// the appropriate mask bit is set
+//
+//   %1 = bitcast i8* %addr to i32*
+//   %2 = extractelement <16 x i1> %mask, i32 0
+//   br i1 %2, label %cond.store, label %else
+//
+// cond.store:                                       ; preds = %0
+//   %3 = extractelement <16 x i32> %val, i32 0
+//   %4 = getelementptr i32* %1, i32 0
+//   store i32 %3, i32* %4
+//   br label %else
+//
+// else:                                             ; preds = %0, %cond.store
+//   %5 = extractelement <16 x i1> %mask, i32 1
+//   br i1 %5, label %cond.store1, label %else2
+//
+// cond.store1:                                      ; preds = %else
+//   %6 = extractelement <16 x i32> %val, i32 1
+//   %7 = getelementptr i32* %1, i32 1
+//   store i32 %6, i32* %7
+//   br label %else2
+//   . . .
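+// A call that would be expanded this way, assuming the target reports it as
+// illegal (illustrative; the suffix follows the usual overload mangling on
+// the data and pointer types):
+//   call void @llvm.masked.store.v16i32.p0v16i32(<16 x i32> %val,
+//                                                <16 x i32>* %addr, i32 4,
+//                                                <16 x i1> %mask)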
+static void scalarizeMaskedStore(CallInst *CI, bool &ModifiedDT) {
+  Value *Src = CI->getArgOperand(0);
+  Value *Ptr = CI->getArgOperand(1);
+  Value *Alignment = CI->getArgOperand(2);
+  Value *Mask = CI->getArgOperand(3);
+
+  unsigned AlignVal = cast<ConstantInt>(Alignment)->getZExtValue();
+  VectorType *VecType = cast<VectorType>(Src->getType());
+
+  Type *EltTy = VecType->getElementType();
+
+  IRBuilder<> Builder(CI->getContext());
+  Instruction *InsertPt = CI;
+  BasicBlock *IfBlock = CI->getParent();
+  Builder.SetInsertPoint(InsertPt);
+  Builder.SetCurrentDebugLocation(CI->getDebugLoc());
+
+  // Short-cut if the mask is all-true.
+  if (isa<Constant>(Mask) && cast<Constant>(Mask)->isAllOnesValue()) {
+    Builder.CreateAlignedStore(Src, Ptr, AlignVal);
+    CI->eraseFromParent();
+    return;
+  }
+
+  // Adjust alignment for the scalar instruction.
+  AlignVal = MinAlign(AlignVal, EltTy->getPrimitiveSizeInBits() / 8);
+  // Bitcast %addr from i8* to EltTy*
+  Type *NewPtrType =
+      EltTy->getPointerTo(Ptr->getType()->getPointerAddressSpace());
+  Value *FirstEltPtr = Builder.CreateBitCast(Ptr, NewPtrType);
+  unsigned VectorWidth = VecType->getNumElements();
+
+  if (isConstantIntVector(Mask)) {
+    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
+      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
+        continue;
+      Value *OneElt = Builder.CreateExtractElement(Src, Idx);
+      Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, FirstEltPtr, Idx);
+      Builder.CreateAlignedStore(OneElt, Gep, AlignVal);
+    }
+    CI->eraseFromParent();
+    return;
+  }
+
+  // If the mask is not v1i1, use scalar bit test operations. This generates
+  // better results on X86 at least.
+  Value *SclrMask;
+  if (VectorWidth != 1) {
+    Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
+    SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
+  }
+
+  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
+    // Fill the "else" block, created in the previous iteration
+    //
+    //  %mask_1 = and i16 %scalar_mask, i32 1 << Idx
+    //  %cond = icmp ne i16 %mask_1, 0
+    //  br i1 %mask_1, label %cond.store, label %else
+    //
+    Value *Predicate;
+    if (VectorWidth != 1) {
+      Value *Mask = Builder.getInt(APInt::getOneBitSet(VectorWidth, Idx));
+      Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
+                                       Builder.getIntN(VectorWidth, 0));
+    } else {
+      Predicate = Builder.CreateExtractElement(Mask, Idx);
+    }
+
+    // Create "cond" block
+    //
+    //  %OneElt = extractelement <16 x i32> %Src, i32 Idx
+    //  %EltAddr = getelementptr i32* %1, i32 0
+    //  %store i32 %OneElt, i32* %EltAddr
+    //
+    BasicBlock *CondBlock =
+        IfBlock->splitBasicBlock(InsertPt->getIterator(), "cond.store");
+    Builder.SetInsertPoint(InsertPt);
+
+    Value *OneElt = Builder.CreateExtractElement(Src, Idx);
+    Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, FirstEltPtr, Idx);
+    Builder.CreateAlignedStore(OneElt, Gep, AlignVal);
+
+    // Create "else" block, fill it in the next iteration
+    BasicBlock *NewIfBlock =
+        CondBlock->splitBasicBlock(InsertPt->getIterator(), "else");
+    Builder.SetInsertPoint(InsertPt);
+    Instruction *OldBr = IfBlock->getTerminator();
+    BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
+    OldBr->eraseFromParent();
+    IfBlock = NewIfBlock;
+  }
+  CI->eraseFromParent();
+
+  ModifiedDT = true;
+}
+
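+// For reference, the pass can also be run on its own with the legacy pass
+// manager (a minimal sketch; assumes a loaded Module M and an initialized
+// target so TTI answers the legality queries accurately):
+//   legacy::PassManager PM;
+//   PM.add(createScalarizeMaskedMemIntrinPass());
+//   PM.run(M);
+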
+// Translate a masked gather intrinsic like
+// <16 x i32> @llvm.masked.gather.v16i32(<16 x i32*> %Ptrs, i32 4,
+//                                       <16 x i1> %Mask, <16 x i32> %Src)
+// to a chain of basic blocks, loading the elements one-by-one if
+// the appropriate mask bit is set
+//
+// %Ptrs = getelementptr i32, i32* %base, <16 x i64> %ind
+// %Mask0 = extractelement <16 x i1> %Mask, i32 0
+// br i1 %Mask0, label %cond.load, label %else
+//
+// cond.load:
+// %Ptr0 = extractelement <16 x i32*> %Ptrs, i32 0
+// %Load0 = load i32, i32* %Ptr0, align 4
+// %Res0 = insertelement <16 x i32> undef, i32 %Load0, i32 0
+// br label %else
+//
+// else:
+// %res.phi.else = phi <16 x i32>[%Res0, %cond.load], [undef, %0]
+// %Mask1 = extractelement <16 x i1> %Mask, i32 1
+// br i1 %Mask1, label %cond.load1, label %else2
+//
+// cond.load1:
+// %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
+// %Load1 = load i32, i32* %Ptr1, align 4
+// %Res1 = insertelement <16 x i32> %res.phi.else, i32 %Load1, i32 1
+// br label %else2
+// . . .
+// %Result = select <16 x i1> %Mask, <16 x i32> %res.phi.select, <16 x i32> %Src
+// ret <16 x i32> %Result
+static void scalarizeMaskedGather(CallInst *CI, bool &ModifiedDT) {
+  Value *Ptrs = CI->getArgOperand(0);
+  Value *Alignment = CI->getArgOperand(1);
+  Value *Mask = CI->getArgOperand(2);
+  Value *Src0 = CI->getArgOperand(3);
+
+  VectorType *VecType = cast<VectorType>(CI->getType());
+  Type *EltTy = VecType->getElementType();
+
+  IRBuilder<> Builder(CI->getContext());
+  Instruction *InsertPt = CI;
+  BasicBlock *IfBlock = CI->getParent();
+  Builder.SetInsertPoint(InsertPt);
+  unsigned AlignVal = cast<ConstantInt>(Alignment)->getZExtValue();
+
+  Builder.SetCurrentDebugLocation(CI->getDebugLoc());
+
+  // The result vector
+  Value *VResult = Src0;
+  unsigned VectorWidth = VecType->getNumElements();
+
+  // Take a shortcut if the mask is a vector of constants.
+  if (isConstantIntVector(Mask)) {
+    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
+      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
+        continue;
+      Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
+      LoadInst *Load =
+          Builder.CreateAlignedLoad(EltTy, Ptr, AlignVal, "Load" + Twine(Idx));
+      VResult =
+          Builder.CreateInsertElement(VResult, Load, Idx, "Res" + Twine(Idx));
+    }
+    CI->replaceAllUsesWith(VResult);
+    CI->eraseFromParent();
+    return;
+  }
+
+  // If the mask is not v1i1, use scalar bit test operations. This generates
+  // better results on X86 at least.
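+  // Note that, unlike the contiguous load case, there is no common base
+  // pointer here: each lane's address is pulled out of %Ptrs, e.g.
+  //   %Ptr5 = extractelement <16 x i32*> %Ptrs, i32 5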
+  Value *SclrMask;
+  if (VectorWidth != 1) {
+    Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
+    SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
+  }
+
+  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
+    // Fill the "else" block, created in the previous iteration
+    //
+    //  %Mask1 = and i16 %scalar_mask, i32 1 << Idx
+    //  %cond = icmp ne i16 %mask_1, 0
+    //  br i1 %Mask1, label %cond.load, label %else
+    //
+
+    Value *Predicate;
+    if (VectorWidth != 1) {
+      Value *Mask = Builder.getInt(APInt::getOneBitSet(VectorWidth, Idx));
+      Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
+                                       Builder.getIntN(VectorWidth, 0));
+    } else {
+      Predicate = Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));
+    }
+
+    // Create "cond" block
+    //
+    //  %EltAddr = getelementptr i32* %1, i32 0
+    //  %Elt = load i32* %EltAddr
+    //  VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
+    //
+    BasicBlock *CondBlock = IfBlock->splitBasicBlock(InsertPt, "cond.load");
+    Builder.SetInsertPoint(InsertPt);
+
+    Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
+    LoadInst *Load =
+        Builder.CreateAlignedLoad(EltTy, Ptr, AlignVal, "Load" + Twine(Idx));
+    Value *NewVResult =
+        Builder.CreateInsertElement(VResult, Load, Idx, "Res" + Twine(Idx));
+
+    // Create "else" block, fill it in the next iteration
+    BasicBlock *NewIfBlock = CondBlock->splitBasicBlock(InsertPt, "else");
+    Builder.SetInsertPoint(InsertPt);
+    Instruction *OldBr = IfBlock->getTerminator();
+    BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
+    OldBr->eraseFromParent();
+    BasicBlock *PrevIfBlock = IfBlock;
+    IfBlock = NewIfBlock;
+
+    PHINode *Phi = Builder.CreatePHI(VecType, 2, "res.phi.else");
+    Phi->addIncoming(NewVResult, CondBlock);
+    Phi->addIncoming(VResult, PrevIfBlock);
+    VResult = Phi;
+  }
+
+  CI->replaceAllUsesWith(VResult);
+  CI->eraseFromParent();
+
+  ModifiedDT = true;
+}
+
+// Translate a masked scatter intrinsic, like
+// void @llvm.masked.scatter.v16i32(<16 x i32> %Src, <16 x i32*> %Ptrs, i32 4,
+//                                  <16 x i1> %Mask)
+// to a chain of basic blocks, storing the elements one-by-one if
+// the appropriate mask bit is set.
+//
+// %Ptrs = getelementptr i32, i32* %ptr, <16 x i64> %ind
+// %Mask0 = extractelement <16 x i1> %Mask, i32 0
+// br i1 %Mask0, label %cond.store, label %else
+//
+// cond.store:
+// %Elt0 = extractelement <16 x i32> %Src, i32 0
+// %Ptr0 = extractelement <16 x i32*> %Ptrs, i32 0
+// store i32 %Elt0, i32* %Ptr0, align 4
+// br label %else
+//
+// else:
+// %Mask1 = extractelement <16 x i1> %Mask, i32 1
+// br i1 %Mask1, label %cond.store1, label %else2
+//
+// cond.store1:
+// %Elt1 = extractelement <16 x i32> %Src, i32 1
+// %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
+// store i32 %Elt1, i32* %Ptr1, align 4
+// br label %else2
+//   . . .
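+// A call that would be expanded this way (illustrative; the suffix is
+// mangled on the data and pointer-vector types):
+//   call void @llvm.masked.scatter.v16i32.v16p0i32(<16 x i32> %Src,
+//                                                  <16 x i32*> %Ptrs, i32 4,
+//                                                  <16 x i1> %Mask)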
+static void scalarizeMaskedScatter(CallInst *CI, bool &ModifiedDT) {
+  Value *Src = CI->getArgOperand(0);
+  Value *Ptrs = CI->getArgOperand(1);
+  Value *Alignment = CI->getArgOperand(2);
+  Value *Mask = CI->getArgOperand(3);
+
+  assert(isa<VectorType>(Src->getType()) &&
+         "Unexpected data type in masked scatter intrinsic");
+  assert(isa<VectorType>(Ptrs->getType()) &&
+         isa<PointerType>(Ptrs->getType()->getVectorElementType()) &&
+         "Vector of pointers is expected in masked scatter intrinsic");
+
+  IRBuilder<> Builder(CI->getContext());
+  Instruction *InsertPt = CI;
+  BasicBlock *IfBlock = CI->getParent();
+  Builder.SetInsertPoint(InsertPt);
+  Builder.SetCurrentDebugLocation(CI->getDebugLoc());
+
+  unsigned AlignVal = cast<ConstantInt>(Alignment)->getZExtValue();
+  unsigned VectorWidth = Src->getType()->getVectorNumElements();
+
+  // Take a shortcut if the mask is a vector of constants.
+  if (isConstantIntVector(Mask)) {
+    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
+      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
+        continue;
+      Value *OneElt =
+          Builder.CreateExtractElement(Src, Idx, "Elt" + Twine(Idx));
+      Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
+      Builder.CreateAlignedStore(OneElt, Ptr, AlignVal);
+    }
+    CI->eraseFromParent();
+    return;
+  }
+
+  // If the mask is not v1i1, use scalar bit test operations. This generates
+  // better results on X86 at least.
+  Value *SclrMask;
+  if (VectorWidth != 1) {
+    Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
+    SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
+  }
+
+  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
+    // Fill the "else" block, created in the previous iteration
+    //
+    //  %Mask1 = and i16 %scalar_mask, i32 1 << Idx
+    //  %cond = icmp ne i16 %mask_1, 0
+    //  br i1 %Mask1, label %cond.store, label %else
+    //
+    Value *Predicate;
+    if (VectorWidth != 1) {
+      Value *Mask = Builder.getInt(APInt::getOneBitSet(VectorWidth, Idx));
+      Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
+                                       Builder.getIntN(VectorWidth, 0));
+    } else {
+      Predicate = Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));
+    }
+
+    // Create "cond" block
+    //
+    //  %Elt1 = extractelement <16 x i32> %Src, i32 1
+    //  %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
+    //  %store i32 %Elt1, i32* %Ptr1
+    //
+    BasicBlock *CondBlock = IfBlock->splitBasicBlock(InsertPt, "cond.store");
+    Builder.SetInsertPoint(InsertPt);
+
+    Value *OneElt = Builder.CreateExtractElement(Src, Idx, "Elt" + Twine(Idx));
+    Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
+    Builder.CreateAlignedStore(OneElt, Ptr, AlignVal);
+
+    // Create "else" block, fill it in the next iteration
+    BasicBlock *NewIfBlock = CondBlock->splitBasicBlock(InsertPt, "else");
+    Builder.SetInsertPoint(InsertPt);
+    Instruction *OldBr = IfBlock->getTerminator();
+    BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
+    OldBr->eraseFromParent();
+    IfBlock = NewIfBlock;
+  }
+  CI->eraseFromParent();
+
+  ModifiedDT = true;
+}
+
+static void scalarizeMaskedExpandLoad(CallInst *CI, bool &ModifiedDT) {
+  Value *Ptr = CI->getArgOperand(0);
+  Value *Mask = CI->getArgOperand(1);
+  Value *PassThru = CI->getArgOperand(2);
+
+  VectorType *VecType = cast<VectorType>(CI->getType());
+
+  Type *EltTy = VecType->getElementType();
+
+  IRBuilder<> Builder(CI->getContext());
+  Instruction *InsertPt = CI;
+  BasicBlock *IfBlock = CI->getParent();
+
+  Builder.SetInsertPoint(InsertPt);
+  Builder.SetCurrentDebugLocation(CI->getDebugLoc());
+
+  unsigned VectorWidth = VecType->getNumElements();
+
+  // The result vector
+  Value *VResult = PassThru;
+
+  // Take a shortcut if the mask is a vector of constants.
+  if (isConstantIntVector(Mask)) {
+    unsigned MemIndex = 0;
+    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
+      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
+        continue;
+      Value *NewPtr = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, MemIndex);
+      LoadInst *Load =
+          Builder.CreateAlignedLoad(EltTy, NewPtr, 1, "Load" + Twine(Idx));
+      VResult =
+          Builder.CreateInsertElement(VResult, Load, Idx, "Res" + Twine(Idx));
+      ++MemIndex;
+    }
+    CI->replaceAllUsesWith(VResult);
+    CI->eraseFromParent();
+    return;
+  }
+
+  // If the mask is not v1i1, use scalar bit test operations. This generates
+  // better results on X86 at least.
+  Value *SclrMask;
+  if (VectorWidth != 1) {
+    Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
+    SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
+  }
+
+  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
+    // Fill the "else" block, created in the previous iteration
+    //
+    //  %res.phi.else3 = phi <16 x i32> [ %11, %cond.load1 ], [ %res.phi.else, %else ]
+    //  %mask_1 = extractelement <16 x i1> %mask, i32 Idx
+    //  br i1 %mask_1, label %cond.load, label %else
+    //
+
+    Value *Predicate;
+    if (VectorWidth != 1) {
+      Value *Mask = Builder.getInt(APInt::getOneBitSet(VectorWidth, Idx));
+      Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
+                                       Builder.getIntN(VectorWidth, 0));
+    } else {
+      Predicate = Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));
+    }
+
+    // Create "cond" block
+    //
+    //  %EltAddr = getelementptr i32* %1, i32 0
+    //  %Elt = load i32* %EltAddr
+    //  VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
+    //
+    BasicBlock *CondBlock = IfBlock->splitBasicBlock(InsertPt->getIterator(),
+                                                     "cond.load");
+    Builder.SetInsertPoint(InsertPt);
+
+    LoadInst *Load = Builder.CreateAlignedLoad(EltTy, Ptr, 1);
+    Value *NewVResult = Builder.CreateInsertElement(VResult, Load, Idx);
+
+    // Move the pointer if there are more blocks to come.
+    Value *NewPtr;
+    if ((Idx + 1) != VectorWidth)
+      NewPtr = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, 1);
+
+    // Create "else" block, fill it in the next iteration
+    BasicBlock *NewIfBlock =
+        CondBlock->splitBasicBlock(InsertPt->getIterator(), "else");
+    Builder.SetInsertPoint(InsertPt);
+    Instruction *OldBr = IfBlock->getTerminator();
+    BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
+    OldBr->eraseFromParent();
+    BasicBlock *PrevIfBlock = IfBlock;
+    IfBlock = NewIfBlock;
+
+    // Create the phi to join the new and previous value.
+    PHINode *ResultPhi = Builder.CreatePHI(VecType, 2, "res.phi.else");
+    ResultPhi->addIncoming(NewVResult, CondBlock);
+    ResultPhi->addIncoming(VResult, PrevIfBlock);
+    VResult = ResultPhi;
+
+    // Add a PHI for the pointer if this isn't the last iteration.
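+    // The pointer advances only past lanes that were actually loaded; in IR
+    // terms (illustrative, block names approximate):
+    //   %ptr.phi.else = phi i32* [ %NewPtr, %cond.load ], [ %Ptr, %else ]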
+    if ((Idx + 1) != VectorWidth) {
+      PHINode *PtrPhi = Builder.CreatePHI(Ptr->getType(), 2, "ptr.phi.else");
+      PtrPhi->addIncoming(NewPtr, CondBlock);
+      PtrPhi->addIncoming(Ptr, PrevIfBlock);
+      Ptr = PtrPhi;
+    }
+  }
+
+  CI->replaceAllUsesWith(VResult);
+  CI->eraseFromParent();
+
+  ModifiedDT = true;
+}
+
+static void scalarizeMaskedCompressStore(CallInst *CI, bool &ModifiedDT) {
+  Value *Src = CI->getArgOperand(0);
+  Value *Ptr = CI->getArgOperand(1);
+  Value *Mask = CI->getArgOperand(2);
+
+  VectorType *VecType = cast<VectorType>(Src->getType());
+
+  IRBuilder<> Builder(CI->getContext());
+  Instruction *InsertPt = CI;
+  BasicBlock *IfBlock = CI->getParent();
+
+  Builder.SetInsertPoint(InsertPt);
+  Builder.SetCurrentDebugLocation(CI->getDebugLoc());
+
+  Type *EltTy = VecType->getVectorElementType();
+
+  unsigned VectorWidth = VecType->getNumElements();
+
+  // Take a shortcut if the mask is a vector of constants.
+  if (isConstantIntVector(Mask)) {
+    unsigned MemIndex = 0;
+    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
+      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
+        continue;
+      Value *OneElt =
+          Builder.CreateExtractElement(Src, Idx, "Elt" + Twine(Idx));
+      Value *NewPtr = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, MemIndex);
+      Builder.CreateAlignedStore(OneElt, NewPtr, 1);
+      ++MemIndex;
+    }
+    CI->eraseFromParent();
+    return;
+  }
+
+  // If the mask is not v1i1, use scalar bit test operations. This generates
+  // better results on X86 at least.
+  Value *SclrMask;
+  if (VectorWidth != 1) {
+    Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
+    SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
+  }
+
+  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
+    // Fill the "else" block, created in the previous iteration
+    //
+    //  %mask_1 = extractelement <16 x i1> %mask, i32 Idx
+    //  br i1 %mask_1, label %cond.store, label %else
+    //
+    Value *Predicate;
+    if (VectorWidth != 1) {
+      Value *Mask = Builder.getInt(APInt::getOneBitSet(VectorWidth, Idx));
+      Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
+                                       Builder.getIntN(VectorWidth, 0));
+    } else {
+      Predicate = Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));
+    }
+
+    // Create "cond" block
+    //
+    //  %OneElt = extractelement <16 x i32> %Src, i32 Idx
+    //  %EltAddr = getelementptr i32* %1, i32 0
+    //  %store i32 %OneElt, i32* %EltAddr
+    //
+    BasicBlock *CondBlock =
+        IfBlock->splitBasicBlock(InsertPt->getIterator(), "cond.store");
+    Builder.SetInsertPoint(InsertPt);
+
+    Value *OneElt = Builder.CreateExtractElement(Src, Idx);
+    Builder.CreateAlignedStore(OneElt, Ptr, 1);
+
+    // Move the pointer if there are more blocks to come.
+    Value *NewPtr;
+    if ((Idx + 1) != VectorWidth)
+      NewPtr = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, 1);
+
+    // Create "else" block, fill it in the next iteration
+    BasicBlock *NewIfBlock =
+        CondBlock->splitBasicBlock(InsertPt->getIterator(), "else");
+    Builder.SetInsertPoint(InsertPt);
+    Instruction *OldBr = IfBlock->getTerminator();
+    BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
+    OldBr->eraseFromParent();
+    BasicBlock *PrevIfBlock = IfBlock;
+    IfBlock = NewIfBlock;
+
+    // Add a PHI for the pointer if this isn't the last iteration.
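+    // As in the expand-load case above, the store pointer advances only
+    // past lanes that were actually written, via the same two-way phi: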
+    if ((Idx + 1) != VectorWidth) {
+      PHINode *PtrPhi = Builder.CreatePHI(Ptr->getType(), 2, "ptr.phi.else");
+      PtrPhi->addIncoming(NewPtr, CondBlock);
+      PtrPhi->addIncoming(Ptr, PrevIfBlock);
+      Ptr = PtrPhi;
+    }
+  }
+  CI->eraseFromParent();
+
+  ModifiedDT = true;
+}
+
+bool ScalarizeMaskedMemIntrin::runOnFunction(Function &F) {
+  bool EverMadeChange = false;
+
+  TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
+
+  bool MadeChange = true;
+  while (MadeChange) {
+    MadeChange = false;
+    for (Function::iterator I = F.begin(); I != F.end();) {
+      BasicBlock *BB = &*I++;
+      bool ModifiedDTOnIteration = false;
+      MadeChange |= optimizeBlock(*BB, ModifiedDTOnIteration);
+
+      // Restart BB iteration if the dominator tree of the Function was changed
+      if (ModifiedDTOnIteration)
+        break;
+    }
+
+    EverMadeChange |= MadeChange;
+  }
+
+  return EverMadeChange;
+}
+
+bool ScalarizeMaskedMemIntrin::optimizeBlock(BasicBlock &BB, bool &ModifiedDT) {
+  bool MadeChange = false;
+
+  BasicBlock::iterator CurInstIterator = BB.begin();
+  while (CurInstIterator != BB.end()) {
+    if (CallInst *CI = dyn_cast<CallInst>(&*CurInstIterator++))
+      MadeChange |= optimizeCallInst(CI, ModifiedDT);
+    if (ModifiedDT)
+      return true;
+  }
+
+  return MadeChange;
+}
+
+bool ScalarizeMaskedMemIntrin::optimizeCallInst(CallInst *CI,
+                                                bool &ModifiedDT) {
+  IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI);
+  if (II) {
+    switch (II->getIntrinsicID()) {
+    default:
+      break;
+    case Intrinsic::masked_load: {
+      // Scalarize unsupported vector masked load
+      unsigned Alignment =
+        cast<ConstantInt>(CI->getArgOperand(1))->getZExtValue();
+      if (TTI->isLegalMaskedLoad(CI->getType(), MaybeAlign(Alignment)))
+        return false;
+      scalarizeMaskedLoad(CI, ModifiedDT);
+      return true;
+    }
+    case Intrinsic::masked_store: {
+      unsigned Alignment =
+        cast<ConstantInt>(CI->getArgOperand(2))->getZExtValue();
+      if (TTI->isLegalMaskedStore(CI->getArgOperand(0)->getType(),
+                                  MaybeAlign(Alignment)))
+        return false;
+      scalarizeMaskedStore(CI, ModifiedDT);
+      return true;
+    }
+    case Intrinsic::masked_gather:
+      if (TTI->isLegalMaskedGather(CI->getType()))
+        return false;
+      scalarizeMaskedGather(CI, ModifiedDT);
+      return true;
+    case Intrinsic::masked_scatter:
+      if (TTI->isLegalMaskedScatter(CI->getArgOperand(0)->getType()))
+        return false;
+      scalarizeMaskedScatter(CI, ModifiedDT);
+      return true;
+    case Intrinsic::masked_expandload:
+      if (TTI->isLegalMaskedExpandLoad(CI->getType()))
+        return false;
+      scalarizeMaskedExpandLoad(CI, ModifiedDT);
+      return true;
+    case Intrinsic::masked_compressstore:
+      if (TTI->isLegalMaskedCompressStore(CI->getArgOperand(0)->getType()))
+        return false;
+      scalarizeMaskedCompressStore(CI, ModifiedDT);
+      return true;
+    }
+  }
+
+  return false;
+}
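To see the expansion in isolation, the pass can be run directly with the legacy pass manager's opt flag, assuming a target where the masked intrinsic is reported illegal (for example plain x86-64 without AVX):

  opt -S -scalarize-masked-mem-intrin masked.ll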
