diff options
Diffstat (limited to 'lib/CodeGen/ScalarizeMaskedMemIntrin.cpp')
-rw-r--r-- | lib/CodeGen/ScalarizeMaskedMemIntrin.cpp | 306 |
1 files changed, 225 insertions, 81 deletions
diff --git a/lib/CodeGen/ScalarizeMaskedMemIntrin.cpp b/lib/CodeGen/ScalarizeMaskedMemIntrin.cpp index 2684f92b3a93..7776dffb4e9c 100644 --- a/lib/CodeGen/ScalarizeMaskedMemIntrin.cpp +++ b/lib/CodeGen/ScalarizeMaskedMemIntrin.cpp @@ -1,10 +1,9 @@ //===- ScalarizeMaskedMemIntrin.cpp - Scalarize unsupported masked mem ----===// // instrinsics // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // @@ -124,7 +123,7 @@ static bool isConstantIntVector(Value *Mask) { // %10 = extractelement <16 x i1> %mask, i32 2 // br i1 %10, label %cond.load4, label %else5 // -static void scalarizeMaskedLoad(CallInst *CI) { +static void scalarizeMaskedLoad(CallInst *CI, bool &ModifiedDT) { Value *Ptr = CI->getArgOperand(0); Value *Alignment = CI->getArgOperand(1); Value *Mask = CI->getArgOperand(2); @@ -144,7 +143,7 @@ static void scalarizeMaskedLoad(CallInst *CI) { // Short-cut if the mask is all-true. if (isa<Constant>(Mask) && cast<Constant>(Mask)->isAllOnesValue()) { - Value *NewI = Builder.CreateAlignedLoad(Ptr, AlignVal); + Value *NewI = Builder.CreateAlignedLoad(VecType, Ptr, AlignVal); CI->replaceAllUsesWith(NewI); CI->eraseFromParent(); return; @@ -152,9 +151,9 @@ static void scalarizeMaskedLoad(CallInst *CI) { // Adjust alignment for the scalar instruction. AlignVal = MinAlign(AlignVal, EltTy->getPrimitiveSizeInBits() / 8); - // Bitcast %addr fron i8* to EltTy* + // Bitcast %addr from i8* to EltTy* Type *NewPtrType = - EltTy->getPointerTo(cast<PointerType>(Ptr->getType())->getAddressSpace()); + EltTy->getPointerTo(Ptr->getType()->getPointerAddressSpace()); Value *FirstEltPtr = Builder.CreateBitCast(Ptr, NewPtrType); unsigned VectorWidth = VecType->getNumElements(); @@ -165,11 +164,9 @@ static void scalarizeMaskedLoad(CallInst *CI) { for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) { if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue()) continue; - Value *Gep = - Builder.CreateInBoundsGEP(EltTy, FirstEltPtr, Builder.getInt32(Idx)); - LoadInst *Load = Builder.CreateAlignedLoad(Gep, AlignVal); - VResult = - Builder.CreateInsertElement(VResult, Load, Builder.getInt32(Idx)); + Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, FirstEltPtr, Idx); + LoadInst *Load = Builder.CreateAlignedLoad(EltTy, Gep, AlignVal); + VResult = Builder.CreateInsertElement(VResult, Load, Idx); } CI->replaceAllUsesWith(VResult); CI->eraseFromParent(); @@ -184,8 +181,7 @@ static void scalarizeMaskedLoad(CallInst *CI) { // br i1 %mask_1, label %cond.load, label %else // - Value *Predicate = - Builder.CreateExtractElement(Mask, Builder.getInt32(Idx)); + Value *Predicate = Builder.CreateExtractElement(Mask, Idx); // Create "cond" block // @@ -197,11 +193,9 @@ static void scalarizeMaskedLoad(CallInst *CI) { "cond.load"); Builder.SetInsertPoint(InsertPt); - Value *Gep = - Builder.CreateInBoundsGEP(EltTy, FirstEltPtr, Builder.getInt32(Idx)); - LoadInst *Load = Builder.CreateAlignedLoad(Gep, AlignVal); - Value *NewVResult = Builder.CreateInsertElement(VResult, Load, - Builder.getInt32(Idx)); + Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, FirstEltPtr, Idx); + LoadInst *Load = Builder.CreateAlignedLoad(EltTy, Gep, AlignVal); + Value *NewVResult = Builder.CreateInsertElement(VResult, Load, Idx); // Create "else" block, fill it in the next iteration BasicBlock *NewIfBlock = @@ -222,6 +216,8 @@ static void scalarizeMaskedLoad(CallInst *CI) { CI->replaceAllUsesWith(VResult); CI->eraseFromParent(); + + ModifiedDT = true; } // Translate a masked store intrinsic, like @@ -250,7 +246,7 @@ static void scalarizeMaskedLoad(CallInst *CI) { // store i32 %6, i32* %7 // br label %else2 // . . . -static void scalarizeMaskedStore(CallInst *CI) { +static void scalarizeMaskedStore(CallInst *CI, bool &ModifiedDT) { Value *Src = CI->getArgOperand(0); Value *Ptr = CI->getArgOperand(1); Value *Alignment = CI->getArgOperand(2); @@ -276,9 +272,9 @@ static void scalarizeMaskedStore(CallInst *CI) { // Adjust alignment for the scalar instruction. AlignVal = MinAlign(AlignVal, EltTy->getPrimitiveSizeInBits() / 8); - // Bitcast %addr fron i8* to EltTy* + // Bitcast %addr from i8* to EltTy* Type *NewPtrType = - EltTy->getPointerTo(cast<PointerType>(Ptr->getType())->getAddressSpace()); + EltTy->getPointerTo(Ptr->getType()->getPointerAddressSpace()); Value *FirstEltPtr = Builder.CreateBitCast(Ptr, NewPtrType); unsigned VectorWidth = VecType->getNumElements(); @@ -286,9 +282,8 @@ static void scalarizeMaskedStore(CallInst *CI) { for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) { if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue()) continue; - Value *OneElt = Builder.CreateExtractElement(Src, Builder.getInt32(Idx)); - Value *Gep = - Builder.CreateInBoundsGEP(EltTy, FirstEltPtr, Builder.getInt32(Idx)); + Value *OneElt = Builder.CreateExtractElement(Src, Idx); + Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, FirstEltPtr, Idx); Builder.CreateAlignedStore(OneElt, Gep, AlignVal); } CI->eraseFromParent(); @@ -301,8 +296,7 @@ static void scalarizeMaskedStore(CallInst *CI) { // %mask_1 = extractelement <16 x i1> %mask, i32 Idx // br i1 %mask_1, label %cond.store, label %else // - Value *Predicate = - Builder.CreateExtractElement(Mask, Builder.getInt32(Idx)); + Value *Predicate = Builder.CreateExtractElement(Mask, Idx); // Create "cond" block // @@ -314,9 +308,8 @@ static void scalarizeMaskedStore(CallInst *CI) { IfBlock->splitBasicBlock(InsertPt->getIterator(), "cond.store"); Builder.SetInsertPoint(InsertPt); - Value *OneElt = Builder.CreateExtractElement(Src, Builder.getInt32(Idx)); - Value *Gep = - Builder.CreateInBoundsGEP(EltTy, FirstEltPtr, Builder.getInt32(Idx)); + Value *OneElt = Builder.CreateExtractElement(Src, Idx); + Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, FirstEltPtr, Idx); Builder.CreateAlignedStore(OneElt, Gep, AlignVal); // Create "else" block, fill it in the next iteration @@ -329,6 +322,8 @@ static void scalarizeMaskedStore(CallInst *CI) { IfBlock = NewIfBlock; } CI->eraseFromParent(); + + ModifiedDT = true; } // Translate a masked gather intrinsic like @@ -360,13 +355,14 @@ static void scalarizeMaskedStore(CallInst *CI) { // . . . // %Result = select <16 x i1> %Mask, <16 x i32> %res.phi.select, <16 x i32> %Src // ret <16 x i32> %Result -static void scalarizeMaskedGather(CallInst *CI) { +static void scalarizeMaskedGather(CallInst *CI, bool &ModifiedDT) { Value *Ptrs = CI->getArgOperand(0); Value *Alignment = CI->getArgOperand(1); Value *Mask = CI->getArgOperand(2); Value *Src0 = CI->getArgOperand(3); VectorType *VecType = cast<VectorType>(CI->getType()); + Type *EltTy = VecType->getElementType(); IRBuilder<> Builder(CI->getContext()); Instruction *InsertPt = CI; @@ -385,12 +381,11 @@ static void scalarizeMaskedGather(CallInst *CI) { for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) { if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue()) continue; - Value *Ptr = Builder.CreateExtractElement(Ptrs, Builder.getInt32(Idx), - "Ptr" + Twine(Idx)); + Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx)); LoadInst *Load = - Builder.CreateAlignedLoad(Ptr, AlignVal, "Load" + Twine(Idx)); - VResult = Builder.CreateInsertElement( - VResult, Load, Builder.getInt32(Idx), "Res" + Twine(Idx)); + Builder.CreateAlignedLoad(EltTy, Ptr, AlignVal, "Load" + Twine(Idx)); + VResult = + Builder.CreateInsertElement(VResult, Load, Idx, "Res" + Twine(Idx)); } CI->replaceAllUsesWith(VResult); CI->eraseFromParent(); @@ -404,8 +399,8 @@ static void scalarizeMaskedGather(CallInst *CI) { // br i1 %Mask1, label %cond.load, label %else // - Value *Predicate = Builder.CreateExtractElement(Mask, Builder.getInt32(Idx), - "Mask" + Twine(Idx)); + Value *Predicate = + Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx)); // Create "cond" block // @@ -416,13 +411,11 @@ static void scalarizeMaskedGather(CallInst *CI) { BasicBlock *CondBlock = IfBlock->splitBasicBlock(InsertPt, "cond.load"); Builder.SetInsertPoint(InsertPt); - Value *Ptr = Builder.CreateExtractElement(Ptrs, Builder.getInt32(Idx), - "Ptr" + Twine(Idx)); + Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx)); LoadInst *Load = - Builder.CreateAlignedLoad(Ptr, AlignVal, "Load" + Twine(Idx)); - Value *NewVResult = Builder.CreateInsertElement(VResult, Load, - Builder.getInt32(Idx), - "Res" + Twine(Idx)); + Builder.CreateAlignedLoad(EltTy, Ptr, AlignVal, "Load" + Twine(Idx)); + Value *NewVResult = + Builder.CreateInsertElement(VResult, Load, Idx, "Res" + Twine(Idx)); // Create "else" block, fill it in the next iteration BasicBlock *NewIfBlock = CondBlock->splitBasicBlock(InsertPt, "else"); @@ -441,6 +434,8 @@ static void scalarizeMaskedGather(CallInst *CI) { CI->replaceAllUsesWith(VResult); CI->eraseFromParent(); + + ModifiedDT = true; } // Translate a masked scatter intrinsic, like @@ -469,7 +464,7 @@ static void scalarizeMaskedGather(CallInst *CI) { // store i32 %Elt1, i32* %Ptr1, align 4 // br label %else2 // . . . -static void scalarizeMaskedScatter(CallInst *CI) { +static void scalarizeMaskedScatter(CallInst *CI, bool &ModifiedDT) { Value *Src = CI->getArgOperand(0); Value *Ptrs = CI->getArgOperand(1); Value *Alignment = CI->getArgOperand(2); @@ -493,12 +488,11 @@ static void scalarizeMaskedScatter(CallInst *CI) { // Shorten the way if the mask is a vector of constants. if (isConstantIntVector(Mask)) { for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) { - if (cast<ConstantVector>(Mask)->getAggregateElement(Idx)->isNullValue()) + if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue()) continue; - Value *OneElt = Builder.CreateExtractElement(Src, Builder.getInt32(Idx), - "Elt" + Twine(Idx)); - Value *Ptr = Builder.CreateExtractElement(Ptrs, Builder.getInt32(Idx), - "Ptr" + Twine(Idx)); + Value *OneElt = + Builder.CreateExtractElement(Src, Idx, "Elt" + Twine(Idx)); + Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx)); Builder.CreateAlignedStore(OneElt, Ptr, AlignVal); } CI->eraseFromParent(); @@ -511,8 +505,8 @@ static void scalarizeMaskedScatter(CallInst *CI) { // %Mask1 = extractelement <16 x i1> %Mask, i32 Idx // br i1 %Mask1, label %cond.store, label %else // - Value *Predicate = Builder.CreateExtractElement(Mask, Builder.getInt32(Idx), - "Mask" + Twine(Idx)); + Value *Predicate = + Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx)); // Create "cond" block // @@ -523,10 +517,8 @@ static void scalarizeMaskedScatter(CallInst *CI) { BasicBlock *CondBlock = IfBlock->splitBasicBlock(InsertPt, "cond.store"); Builder.SetInsertPoint(InsertPt); - Value *OneElt = Builder.CreateExtractElement(Src, Builder.getInt32(Idx), - "Elt" + Twine(Idx)); - Value *Ptr = Builder.CreateExtractElement(Ptrs, Builder.getInt32(Idx), - "Ptr" + Twine(Idx)); + Value *OneElt = Builder.CreateExtractElement(Src, Idx, "Elt" + Twine(Idx)); + Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx)); Builder.CreateAlignedStore(OneElt, Ptr, AlignVal); // Create "else" block, fill it in the next iteration @@ -538,6 +530,156 @@ static void scalarizeMaskedScatter(CallInst *CI) { IfBlock = NewIfBlock; } CI->eraseFromParent(); + + ModifiedDT = true; +} + +static void scalarizeMaskedExpandLoad(CallInst *CI, bool &ModifiedDT) { + Value *Ptr = CI->getArgOperand(0); + Value *Mask = CI->getArgOperand(1); + Value *PassThru = CI->getArgOperand(2); + + VectorType *VecType = cast<VectorType>(CI->getType()); + + Type *EltTy = VecType->getElementType(); + + IRBuilder<> Builder(CI->getContext()); + Instruction *InsertPt = CI; + BasicBlock *IfBlock = CI->getParent(); + + Builder.SetInsertPoint(InsertPt); + Builder.SetCurrentDebugLocation(CI->getDebugLoc()); + + unsigned VectorWidth = VecType->getNumElements(); + + // The result vector + Value *VResult = PassThru; + + for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) { + // Fill the "else" block, created in the previous iteration + // + // %res.phi.else3 = phi <16 x i32> [ %11, %cond.load1 ], [ %res.phi.else, %else ] + // %mask_1 = extractelement <16 x i1> %mask, i32 Idx + // br i1 %mask_1, label %cond.load, label %else + // + + Value *Predicate = + Builder.CreateExtractElement(Mask, Idx); + + // Create "cond" block + // + // %EltAddr = getelementptr i32* %1, i32 0 + // %Elt = load i32* %EltAddr + // VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx + // + BasicBlock *CondBlock = IfBlock->splitBasicBlock(InsertPt->getIterator(), + "cond.load"); + Builder.SetInsertPoint(InsertPt); + + LoadInst *Load = Builder.CreateAlignedLoad(EltTy, Ptr, 1); + Value *NewVResult = Builder.CreateInsertElement(VResult, Load, Idx); + + // Move the pointer if there are more blocks to come. + Value *NewPtr; + if ((Idx + 1) != VectorWidth) + NewPtr = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, 1); + + // Create "else" block, fill it in the next iteration + BasicBlock *NewIfBlock = + CondBlock->splitBasicBlock(InsertPt->getIterator(), "else"); + Builder.SetInsertPoint(InsertPt); + Instruction *OldBr = IfBlock->getTerminator(); + BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr); + OldBr->eraseFromParent(); + BasicBlock *PrevIfBlock = IfBlock; + IfBlock = NewIfBlock; + + // Create the phi to join the new and previous value. + PHINode *ResultPhi = Builder.CreatePHI(VecType, 2, "res.phi.else"); + ResultPhi->addIncoming(NewVResult, CondBlock); + ResultPhi->addIncoming(VResult, PrevIfBlock); + VResult = ResultPhi; + + // Add a PHI for the pointer if this isn't the last iteration. + if ((Idx + 1) != VectorWidth) { + PHINode *PtrPhi = Builder.CreatePHI(Ptr->getType(), 2, "ptr.phi.else"); + PtrPhi->addIncoming(NewPtr, CondBlock); + PtrPhi->addIncoming(Ptr, PrevIfBlock); + Ptr = PtrPhi; + } + } + + CI->replaceAllUsesWith(VResult); + CI->eraseFromParent(); + + ModifiedDT = true; +} + +static void scalarizeMaskedCompressStore(CallInst *CI, bool &ModifiedDT) { + Value *Src = CI->getArgOperand(0); + Value *Ptr = CI->getArgOperand(1); + Value *Mask = CI->getArgOperand(2); + + VectorType *VecType = cast<VectorType>(Src->getType()); + + IRBuilder<> Builder(CI->getContext()); + Instruction *InsertPt = CI; + BasicBlock *IfBlock = CI->getParent(); + + Builder.SetInsertPoint(InsertPt); + Builder.SetCurrentDebugLocation(CI->getDebugLoc()); + + Type *EltTy = VecType->getVectorElementType(); + + unsigned VectorWidth = VecType->getNumElements(); + + for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) { + // Fill the "else" block, created in the previous iteration + // + // %mask_1 = extractelement <16 x i1> %mask, i32 Idx + // br i1 %mask_1, label %cond.store, label %else + // + Value *Predicate = Builder.CreateExtractElement(Mask, Idx); + + // Create "cond" block + // + // %OneElt = extractelement <16 x i32> %Src, i32 Idx + // %EltAddr = getelementptr i32* %1, i32 0 + // %store i32 %OneElt, i32* %EltAddr + // + BasicBlock *CondBlock = + IfBlock->splitBasicBlock(InsertPt->getIterator(), "cond.store"); + Builder.SetInsertPoint(InsertPt); + + Value *OneElt = Builder.CreateExtractElement(Src, Idx); + Builder.CreateAlignedStore(OneElt, Ptr, 1); + + // Move the pointer if there are more blocks to come. + Value *NewPtr; + if ((Idx + 1) != VectorWidth) + NewPtr = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, 1); + + // Create "else" block, fill it in the next iteration + BasicBlock *NewIfBlock = + CondBlock->splitBasicBlock(InsertPt->getIterator(), "else"); + Builder.SetInsertPoint(InsertPt); + Instruction *OldBr = IfBlock->getTerminator(); + BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr); + OldBr->eraseFromParent(); + BasicBlock *PrevIfBlock = IfBlock; + IfBlock = NewIfBlock; + + // Add a PHI for the pointer if this isn't the last iteration. + if ((Idx + 1) != VectorWidth) { + PHINode *PtrPhi = Builder.CreatePHI(Ptr->getType(), 2, "ptr.phi.else"); + PtrPhi->addIncoming(NewPtr, CondBlock); + PtrPhi->addIncoming(Ptr, PrevIfBlock); + Ptr = PtrPhi; + } + } + CI->eraseFromParent(); + + ModifiedDT = true; } bool ScalarizeMaskedMemIntrin::runOnFunction(Function &F) { @@ -587,33 +729,35 @@ bool ScalarizeMaskedMemIntrin::optimizeCallInst(CallInst *CI, break; case Intrinsic::masked_load: // Scalarize unsupported vector masked load - if (!TTI->isLegalMaskedLoad(CI->getType())) { - scalarizeMaskedLoad(CI); - ModifiedDT = true; - return true; - } - return false; + if (TTI->isLegalMaskedLoad(CI->getType())) + return false; + scalarizeMaskedLoad(CI, ModifiedDT); + return true; case Intrinsic::masked_store: - if (!TTI->isLegalMaskedStore(CI->getArgOperand(0)->getType())) { - scalarizeMaskedStore(CI); - ModifiedDT = true; - return true; - } - return false; + if (TTI->isLegalMaskedStore(CI->getArgOperand(0)->getType())) + return false; + scalarizeMaskedStore(CI, ModifiedDT); + return true; case Intrinsic::masked_gather: - if (!TTI->isLegalMaskedGather(CI->getType())) { - scalarizeMaskedGather(CI); - ModifiedDT = true; - return true; - } - return false; + if (TTI->isLegalMaskedGather(CI->getType())) + return false; + scalarizeMaskedGather(CI, ModifiedDT); + return true; case Intrinsic::masked_scatter: - if (!TTI->isLegalMaskedScatter(CI->getArgOperand(0)->getType())) { - scalarizeMaskedScatter(CI); - ModifiedDT = true; - return true; - } - return false; + if (TTI->isLegalMaskedScatter(CI->getArgOperand(0)->getType())) + return false; + scalarizeMaskedScatter(CI, ModifiedDT); + return true; + case Intrinsic::masked_expandload: + if (TTI->isLegalMaskedExpandLoad(CI->getType())) + return false; + scalarizeMaskedExpandLoad(CI, ModifiedDT); + return true; + case Intrinsic::masked_compressstore: + if (TTI->isLegalMaskedCompressStore(CI->getArgOperand(0)->getType())) + return false; + scalarizeMaskedCompressStore(CI, ModifiedDT); + return true; } } |